Skip to content

Commit

Permalink
fix: bleh
Browse files Browse the repository at this point in the history
  • Loading branch information
Lulalaby committed May 2, 2024
1 parent 047e466 commit 5dec1c6
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ public async Task<IReadOnlyList<AuthorizationCodeExchangeEventArgs>> CollectMatc
this._collectRequests.TryRemove(request);
}

return request.Collected.ToArray();
return [.. request.Collected];
}

/// <summary>
Expand Down
10 changes: 10 additions & 0 deletions DisCatSharp.Extensions.OAuth2Web/ExtensionMethods.cs
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,16 @@ public static async Task StopAsync(this IReadOnlyDictionary<int, OAuth2WebExtens
public static bool HasAllRequiredRedirectUrisSet(this IReadOnlyDictionary<int, OAuth2WebExtension> extensions, DiscordShardedClient client)
=> extensions.Values.Select(extension => extension.Configuration.RedirectUri).All(client.CurrentApplication.RedirectUris.Contains);

/// <summary>
/// <para>Checks if the redirect uri is set for the application in the developer portal.</para>
/// <para>Use this function after you've executed <see cref="DiscordClient.ConnectAsync"/>.</para>
/// </summary>
/// <param name="extensions">The extensions.</param>
/// <param name="client">The <see cref="DiscordClient"/>.</param>
/// <returns>Whether the required redirect uris is set.</returns>
public static bool HasRequiredRedirectUriSet(this OAuth2WebExtension extensions, DiscordClient client)
=> client.CurrentApplication.RedirectUris.Contains(extensions.Configuration.RedirectUri);

/// <summary>
/// Gets the required redirect uris for the developer portal.
/// </summary>
Expand Down
38 changes: 23 additions & 15 deletions DisCatSharp.Extensions.OAuth2Web/OAuth2WebConfiguration.cs
Original file line number Diff line number Diff line change
Expand Up @@ -118,23 +118,30 @@ public sealed class OAuth2WebConfiguration
/// </summary>
public string? HtmlOutputException { internal get; init; } = null;

/// <summary>
/// <para>Sets the minimum logging level for messages.</para>
/// <para>Defaults to <see cref="LogLevel.Information"/>.</para>
/// </summary>
public LogLevel MinimumLogLevel { internal get; set; } = LogLevel.Information;
/// <summary>
/// <para>Sets the minimum logging level for messages.</para>
/// <para>Defaults to <see cref="LogLevel.Information"/>.</para>
/// </summary>
public LogLevel MinimumLogLevel { internal get; set; } = LogLevel.Information;

/// <summary>
/// <para>Allows you to overwrite the time format used by the internal debug logger.</para>
/// <para>Only applicable when <see cref="LoggerFactory"/> is set left at default value. Defaults to ISO 8601-like format.</para>
/// </summary>
public string LogTimestampFormat { internal get; set; } = "yyyy-MM-dd HH:mm:ss zzz";
/// <summary>
/// <para>Allows you to overwrite the time format used by the internal debug logger.</para>
/// <para>Only applicable when <see cref="LoggerFactory"/> is set left at default value. Defaults to ISO 8601-like format.</para>
/// </summary>
public string LogTimestampFormat { internal get; set; } = "yyyy-MM-dd HH:mm:ss zzz";

/// <summary>
/// <para>Sets the proxy to use for HTTP connections to Discord.</para>
/// <para>Defaults to <see langword="null"/>.</para>
/// </summary>
public IWebProxy? Proxy { internal get; set; } = null;
/// <summary>
/// <para>Sets the proxy to use for HTTP connections to Discord.</para>
/// <para>Defaults to <see langword="null"/>.</para>
/// </summary>
public IWebProxy? Proxy { internal get; set; } = null;

/// <summary>
/// <para>Sets the logger implementation to use.</para>
/// <para>To create your own logger, implement the <see cref="Microsoft.Extensions.Logging.ILoggerFactory"/> instance.</para>
/// <para>Defaults to built-in implementation.</para>
/// </summary>
public ILoggerFactory? LoggerFactory { internal get; set; } = null;

/// <summary>
/// Creates a new instance of <see cref="OAuth2WebConfiguration"/>.
Expand Down Expand Up @@ -174,5 +181,6 @@ public OAuth2WebConfiguration(OAuth2WebConfiguration other)
this.Proxy = other.Proxy;
this.LogTimestampFormat = other.LogTimestampFormat;
this.MinimumLogLevel = other.MinimumLogLevel;
this.LoggerFactory = other.LoggerFactory;
}
}
27 changes: 19 additions & 8 deletions DisCatSharp.Extensions.OAuth2Web/OAuth2WebExtension.cs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Extensions;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;

namespace DisCatSharp.Extensions.OAuth2Web;

Expand All @@ -49,6 +50,11 @@ namespace DisCatSharp.Extensions.OAuth2Web;
/// </summary>
public sealed class OAuth2WebExtension : BaseExtension
{
/// <summary>
/// Gets the logger for this extension.
/// </summary>
public ILogger<OAuth2WebExtension> Logger { get; private set; }

/// <summary>
/// Gets the OAuth2 Web configuration.
/// </summary>
Expand All @@ -67,8 +73,7 @@ public sealed class OAuth2WebExtension : BaseExtension
/// <summary>
/// Gets the service provider this OAuth2 Web module was configured with.
/// </summary>
public IServiceProvider ServiceProvider
=> this.Configuration.ServiceProvider;
public IServiceProvider ServiceProvider { get; private set; }

/// <summary>
/// Gets the authorization code event waiter.
Expand Down Expand Up @@ -149,14 +154,18 @@ internal OAuth2WebExtension(OAuth2WebConfiguration configuration) // , DiscordCl
{
this.Configuration = configuration;

this.OAuth2Client = new(this.Configuration.ClientId, this.Configuration.ClientSecret, this.Configuration.RedirectUri, this.ServiceProvider, this.Configuration.Proxy, null, true, null, this.Configuration.MinimumLogLevel, this.Configuration.LogTimestampFormat); // , discordClient: discordClient);
this.OAuth2Client = new(this.Configuration.ClientId, this.Configuration.ClientSecret, this.Configuration.RedirectUri, this.ServiceProvider, this.Configuration.Proxy, default, default, this.Configuration.LoggerFactory, this.Configuration.MinimumLogLevel, this.Configuration.LogTimestampFormat); // , discordClient: discordClient);

this._authorizationCodeReceived = new("OAUTH2_AUTH_CODE_RECEIVED", TimeSpan.Zero, this.OAuth2Client.EventErrorHandler);
this._authorizationCodeExchanged = new("OAUTH2_AUTH_CODE_EXCHANGED", TimeSpan.Zero, this.OAuth2Client.EventErrorHandler);
this._accessTokenRefreshed = new("OAUTH2_ACCESS_TOKEN_REFRESHED", TimeSpan.Zero, this.OAuth2Client.EventErrorHandler);
this._accessTokenRevoked = new("OAUTH2_ACCESS_TOKEN_REVOKED", TimeSpan.Zero, this.OAuth2Client.EventErrorHandler);
this._authorizationCodeWaiter = new(this, this.OAuth2Client);

this.AuthorizationCodeExchanged += this.OnAuthorizationCodeExchangedAsync;
this.AccessTokenRefreshed += this.OnAccessTokenRefreshedAsync;
this.AccessTokenRevoked += this.OnAccessTokenRevokedAsync;

var builder = WebApplication.CreateBuilder();

builder.Services.AddRouting();
Expand All @@ -172,12 +181,8 @@ internal OAuth2WebExtension(OAuth2WebConfiguration configuration) // , DiscordCl

this.WEB_APP.UseAuthorization();

this.WEB_APP.MapGet("/oauth/{shard}", this.HandleOAuth2Async);
this.WEB_APP.MapGet("/oauth/{shard}/", this.HandleOAuth2Async);
this.WEB_APP.MapGet("/oauth/", this.HandleOAuth2Async);

this.AuthorizationCodeExchanged += this.OnAuthorizationCodeExchangedAsync;
this.AccessTokenRefreshed += this.OnAccessTokenRefreshedAsync;
this.AccessTokenRevoked += this.OnAccessTokenRevokedAsync;
}

/// <summary>
Expand Down Expand Up @@ -353,6 +358,9 @@ protected internal override void Setup(DiscordClient client)
this.Repository = "DisCatSharp.Extensions";

this.PackageId = "DisCatSharp.Extensions.OAuth2Web";

this.Logger = (this.Configuration.LoggerFactory ?? this.Client.Configuration.LoggerFactory).CreateLogger<OAuth2WebExtension>();
this.ServiceProvider = this.Configuration.ServiceProvider;
}

/// <summary>
Expand Down Expand Up @@ -491,7 +499,10 @@ await this._authorizationCodeReceived.InvokeAsync(this.OAuth2Client,
{
var stateUserId = ulong.Parse(this.OAuth2Client.ReadSecureState(state).Split("::")[1]);
if (stateUserId != info.User?.Id)
{
this.Logger.LogCritical("OAuth2Web::SecurityException - Received user id {receivedUserId} does not matches authorized user id {authorizedUserId} or authorized is null.", stateUserId, info.User?.Id);
throw new SecurityException("State mismatch");
}
}

var targetPending = this.OAuth2RequestUrls.First(u => this.OAuth2Client.ValidateState(new(u), requestUrl, this.Configuration.SecureStates));
Expand Down

0 comments on commit 5dec1c6

Please sign in to comment.