Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Implement cooldowns for application commands #431

Merged
merged 16 commits into from
Jan 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 48 additions & 26 deletions DisCatSharp.ApplicationCommands/ApplicationCommandsExtension.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@
using DisCatSharp.ApplicationCommands.EventArgs;
using DisCatSharp.ApplicationCommands.Exceptions;
using DisCatSharp.ApplicationCommands.Workers;
using DisCatSharp.Attributes;
using DisCatSharp.Common;
using DisCatSharp.Common.Utilities;
using DisCatSharp.Entities;
using DisCatSharp.Enums;
using DisCatSharp.Enums.Core;
using DisCatSharp.EventArgs;
using DisCatSharp.Exceptions;

Expand Down Expand Up @@ -661,39 +661,37 @@ private async Task RegisterCommands(List<ApplicationCommandsModuleConfiguration>
if (Configuration.GenerateTranslationFilesOnly)
{
var cgwsgs = new List<CommandGroupWithSubGroups>();
var cgs2 = new List<CommandGroup>();
foreach (var cmd in slashGroupsTuple.applicationCommands)
if (cmd.Type is ApplicationCommandType.ChatInput)
{
var cgs = new List<CommandGroup>();
var cs2 = new List<Command>();
if (cmd.Options is not null)
{
foreach (var scg in cmd.Options.Where(x => x.Type is ApplicationCommandOptionType.SubCommandGroup))
{
var cs = new List<Command>();
if (scg.Options is not null)
foreach (var sc in scg.Options)
if (sc.Options is null || sc.Options.Count is 0)
cs.Add(new(sc.Name, sc.Description, null, null));
cs.Add(new(sc.Name, sc.Description, null, null, sc.RawNameLocalizations, sc.RawDescriptionLocalizations));
else
cs.Add(new(sc.Name, sc.Description, [.. sc.Options], null));
cgs.Add(new(scg.Name, scg.Description, cs, null));
cs.Add(new(sc.Name, sc.Description, [.. sc.Options], null, sc.RawNameLocalizations, sc.RawDescriptionLocalizations));
cgs.Add(new(scg.Name, scg.Description, cs, null, scg.RawNameLocalizations, scg.RawDescriptionLocalizations));
}

cgwsgs.Add(new(cmd.Name, cmd.Description, cgs, cmd.Type));
foreach (var sc2 in cmd.Options.Where(x => x.Type is ApplicationCommandOptionType.SubCommand))
if (sc2.Options == null || sc2.Options.Count == 0)
cs2.Add(new(sc2.Name, sc2.Description, null, null, sc2.RawNameLocalizations, sc2.RawDescriptionLocalizations));
else
cs2.Add(new(sc2.Name, sc2.Description, [.. sc2.Options], null, sc2.RawNameLocalizations, sc2.RawDescriptionLocalizations));
}

var cs2 = new List<Command>();
foreach (var sc2 in cmd.Options.Where(x => x.Type is ApplicationCommandOptionType.SubCommand))
if (sc2.Options == null || sc2.Options.Count == 0)
cs2.Add(new(sc2.Name, sc2.Description, null, null));
else
cs2.Add(new(sc2.Name, sc2.Description, [.. sc2.Options], null));
cgs2.Add(new(cmd.Name, cmd.Description, cs2, cmd.Type));
cgwsgs.Add(new(cmd.Name, cmd.Description, cgs, cs2, cmd.Type, cmd.RawNameLocalizations, cmd.RawDescriptionLocalizations));
}

if (cgwsgs.Count is not 0)
groupTranslation.AddRange(cgwsgs.Select(cgwsg => JsonConvert.DeserializeObject<GroupTranslator>(JsonConvert.SerializeObject(cgwsg))!));
if (cgs2.Count is not 0)
groupTranslation.AddRange(cgs2.Select(cg2 => JsonConvert.DeserializeObject<GroupTranslator>(JsonConvert.SerializeObject(cg2))!));
}
}

Expand Down Expand Up @@ -733,12 +731,20 @@ private async Task RegisterCommands(List<ApplicationCommandsModuleConfiguration>
var cs = new List<Command>();
foreach (var cmd in slashCommands.applicationCommands.Where(cmd => cmd.Type is ApplicationCommandType.ChatInput && (cmd.Options is null || !cmd.Options.Any(x => x.Type is ApplicationCommandOptionType.SubCommand or ApplicationCommandOptionType.SubCommandGroup))))
if (cmd.Options == null || cmd.Options.Count == 0)
cs.Add(new(cmd.Name, cmd.Description, null, ApplicationCommandType.ChatInput));
cs.Add(new(cmd.Name, cmd.Description, null, ApplicationCommandType.ChatInput, cmd.RawNameLocalizations, cmd.RawDescriptionLocalizations));
else
cs.Add(new(cmd.Name, cmd.Description, [.. cmd.Options], ApplicationCommandType.ChatInput));
cs.Add(new(cmd.Name, cmd.Description, [.. cmd.Options], ApplicationCommandType.ChatInput, cmd.RawNameLocalizations, cmd.RawDescriptionLocalizations));

if (cs.Count is not 0)
translation.AddRange(cs.Select(c => JsonConvert.DeserializeObject<CommandTranslator>(JsonConvert.SerializeObject(c))!));
//translation.AddRange(cs.Select(c => JsonConvert.DeserializeObject<CommandTranslator>(JsonConvert.SerializeObject(c))!));
{
foreach (var c in cs)
{
var json = JsonConvert.SerializeObject(c);
var obj = JsonConvert.DeserializeObject<CommandTranslator>(json);
translation.Add(obj!);
}
}
}
}

Expand Down Expand Up @@ -804,7 +810,7 @@ private async Task RegisterCommands(List<ApplicationCommandsModuleConfiguration>
{
updateList = updateList.DistinctBy(x => x.Name).ToList();
if (Configuration.GenerateTranslationFilesOnly)
await this.CheckRegistrationStartup(translation, groupTranslation);
await this.CheckRegistrationStartup(translation, groupTranslation, guildId);
else
try
{
Expand Down Expand Up @@ -911,7 +917,7 @@ private async Task RegisterCommands(List<ApplicationCommandsModuleConfiguration>
RegisteredCommands = GlobalCommandsInternal
}).ConfigureAwait(false);

await this.CheckRegistrationStartup(translation, groupTranslation);
await this.CheckRegistrationStartup(translation, groupTranslation, guildId);
}
catch (NullReferenceException ex)
{
Expand Down Expand Up @@ -965,15 +971,16 @@ private async Task RegisterCommands(List<ApplicationCommandsModuleConfiguration>
/// </summary>
/// <param name="translation">The optional translations.</param>
/// <param name="groupTranslation">The optional group translations.</param>
private async Task CheckRegistrationStartup(List<CommandTranslator>? translation = null, List<GroupTranslator>? groupTranslation = null)
/// <param name="guildId">The optional guild id.</param>
private async Task CheckRegistrationStartup(List<CommandTranslator>? translation = null, List<GroupTranslator>? groupTranslation = null, ulong? guildId = null)
{
if (Configuration.GenerateTranslationFilesOnly)
{
try
{
if (translation is not null && translation.Count is not 0)
{
var fileName = $"translation_generator_export-shard{this.Client.ShardId}-SINGLE.json";
var fileName = $"translation_generator_export-shard{this.Client.ShardId}-SINGLE-{(guildId.HasValue ? guildId.Value : "global")}.json";
var fs = File.Create(fileName);
var ms = new MemoryStream();
var writer = new StreamWriter(ms);
Expand All @@ -991,7 +998,7 @@ private async Task CheckRegistrationStartup(List<CommandTranslator>? translation

if (groupTranslation is not null && groupTranslation.Count is not 0)
{
var fileName = $"translation_generator_export-shard{this.Client.ShardId}-GROUP.json";
var fileName = $"translation_generator_export-shard{this.Client.ShardId}-GROUP-{(guildId.HasValue ? guildId.Value : "global")}.json";
var fs = File.Create(fileName);
var ms = new MemoryStream();
var writer = new StreamWriter(ms);
Expand Down Expand Up @@ -1030,6 +1037,8 @@ private async Task CheckStartupFinishAsync(ApplicationCommandsExtension sender,
GuildsWithoutScope = s_missingScopeGuildIdsGlobal
}).ConfigureAwait(false);
FinishFired = true;
if (Configuration.GenerateTranslationFilesOnly)
Environment.Exit(0);
}

args.Handled = false;
Expand Down Expand Up @@ -1081,7 +1090,11 @@ private Task InteractionHandler(DiscordClient client, InteractionCreateEventArgs
GuildLocale = e.Interaction.GuildLocale,
AppPermissions = e.Interaction.AppPermissions,
Entitlements = e.Interaction.Entitlements,
EntitlementSkuIds = e.Interaction.EntitlementSkuIds
EntitlementSkuIds = e.Interaction.EntitlementSkuIds,
UserId = e.Interaction.User.Id,
GuildId = e.Interaction.GuildId,
MemberId = e.Interaction.GuildId is not null ? e.Interaction.User.Id : null,
ChannelId = e.Interaction.ChannelId
};

try
Expand Down Expand Up @@ -1340,7 +1353,12 @@ private Task ContextMenuHandler(DiscordClient client, ContextMenuInteractionCrea
_ = Task.Run(async () =>
{
//Creates the context
var context = new ContextMenuContext
var context = new ContextMenuContext(e.Type switch
{
ApplicationCommandType.User => DisCatSharpCommandType.UserCommand,
ApplicationCommandType.Message => DisCatSharpCommandType.MessageCommand,
_ => throw new ArgumentOutOfRangeException(nameof(e.Type), "Unknown context menu type")
})
{
Interaction = e.Interaction,
Channel = e.Interaction.Channel,
Expand All @@ -1359,7 +1377,11 @@ private Task ContextMenuHandler(DiscordClient client, ContextMenuInteractionCrea
GuildLocale = e.Interaction.GuildLocale,
AppPermissions = e.Interaction.AppPermissions,
Entitlements = e.Interaction.Entitlements,
EntitlementSkuIds = e.Interaction.EntitlementSkuIds
EntitlementSkuIds = e.Interaction.EntitlementSkuIds,
UserId = e.Interaction.User.Id,
GuildId = e.Interaction.GuildId,
MemberId = e.Interaction.GuildId is not null ? e.Interaction.User.Id : null,
ChannelId = e.Interaction.ChannelId
};

try
Expand Down
Original file line number Diff line number Diff line change
@@ -1,63 +1,62 @@
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Globalization;
using System.Threading.Tasks;

using DisCatSharp.ApplicationCommands.Context;
using DisCatSharp.ApplicationCommands.Entities;
using DisCatSharp.ApplicationCommands.Enums;
using DisCatSharp.Entities;
using DisCatSharp.Entities.Core;
using DisCatSharp.Enums;
using DisCatSharp.Enums.Core;

using Sentry;

namespace DisCatSharp.ApplicationCommands.Attributes;

/// <summary>
/// Defines a cooldown for this command. This allows you to define how many times can users execute a specific command
/// </summary>
/// <remarks>
/// Defines a cooldown for this command. This means that users will be able to use the command a specific number of times before they have to wait to use it again.
/// </remarks>
/// <param name="maxUses">Number of times the command can be used before triggering a cooldown.</param>
/// <param name="resetAfter">Number of seconds after which the cooldown is reset.</param>
/// <param name="bucketType">Type of cooldown bucket. This allows controlling whether the bucket will be cooled down per user, guild, member, channel, and/or globally.</param>
/// <param name="cooldownResponderType">The responder type used to respond to cooldown ratelimit hits.</param>
[AttributeUsage(AttributeTargets.Method | AttributeTargets.Class, AllowMultiple = true, Inherited = false)]
public sealed class ContextMenuCooldownAttribute : ApplicationCommandCheckBaseAttribute, ICooldown<BaseContext, ContextMenuCooldownBucket>
public sealed class ContextMenuCooldownAttribute(int maxUses, double resetAfter, CooldownBucketType bucketType, Type? cooldownResponderType = null) : ApplicationCommandCheckBaseAttribute, ICooldown<BaseContext, CooldownBucket>
{
/// <summary>
/// Gets the maximum number of uses before this command triggers a cooldown for its bucket.
/// </summary>
public int MaxUses { get; }
public int MaxUses { get; } = maxUses;

/// <summary>
/// Gets the time after which the cooldown is reset.
/// </summary>
public TimeSpan Reset { get; }
public TimeSpan Reset { get; } = TimeSpan.FromSeconds(resetAfter);

/// <summary>
/// Gets the type of the cooldown bucket. This determines how cooldowns are applied.
/// </summary>
public CooldownBucketType BucketType { get; }

/// <summary>
/// Gets the cooldown buckets for this command.
/// </summary>
internal readonly ConcurrentDictionary<string, ContextMenuCooldownBucket> Buckets;
public CooldownBucketType BucketType { get; } = bucketType;

/// <summary>
/// Defines a cooldown for this command. This means that users will be able to use the command a specific number of times before they have to wait to use it again.
/// Gets the responder type.
/// </summary>
/// <param name="maxUses">Number of times the command can be used before triggering a cooldown.</param>
/// <param name="resetAfter">Number of seconds after which the cooldown is reset.</param>
/// <param name="bucketType">Type of cooldown bucket. This allows controlling whether the bucket will be cooled down per user, guild, channel, or globally.</param>
public ContextMenuCooldownAttribute(int maxUses, double resetAfter, CooldownBucketType bucketType)
{
this.MaxUses = maxUses;
this.Reset = TimeSpan.FromSeconds(resetAfter);
this.BucketType = bucketType;
this.Buckets = new();
}
public Type? ResponderType { get; } = cooldownResponderType;

/// <summary>
/// Gets a cooldown bucket for given command context.
/// </summary>
/// <param name="ctx">Command context to get cooldown bucket for.</param>
/// <returns>Requested cooldown bucket, or null if one wasn't present.</returns>
public ContextMenuCooldownBucket GetBucket(BaseContext ctx)
public CooldownBucket GetBucket(BaseContext ctx)
{
var bid = this.GetBucketId(ctx, out _, out _, out _);
this.Buckets.TryGetValue(bid, out var bucket);
return bucket;
var bid = this.GetBucketId(ctx, out _, out _, out _, out _);
ctx.Client.CommandCooldownBuckets.TryGetValue(bid, out var bucket);
return bucket!;
}

/// <summary>
Expand All @@ -68,7 +67,7 @@ public ContextMenuCooldownBucket GetBucket(BaseContext ctx)
public TimeSpan GetRemainingCooldown(BaseContext ctx)
{
var bucket = this.GetBucket(ctx);
return bucket == null
return bucket == null!
? TimeSpan.Zero
: bucket.RemainingUses > 0
? TimeSpan.Zero
Expand All @@ -82,8 +81,9 @@ public TimeSpan GetRemainingCooldown(BaseContext ctx)
/// <param name="userId">ID of the user with which this bucket is associated.</param>
/// <param name="channelId">ID of the channel with which this bucket is associated.</param>
/// <param name="guildId">ID of the guild with which this bucket is associated.</param>
/// <param name="memberId">ID of the member with which this bucket is associated.</param>
/// <returns>Calculated bucket ID.</returns>
private string GetBucketId(BaseContext ctx, out ulong userId, out ulong channelId, out ulong guildId)
private string GetBucketId(BaseContext ctx, out ulong userId, out ulong channelId, out ulong guildId, out ulong memberId)
{
userId = 0ul;
if ((this.BucketType & CooldownBucketType.User) != 0)
Expand All @@ -92,14 +92,16 @@ private string GetBucketId(BaseContext ctx, out ulong userId, out ulong channelI
channelId = 0ul;
if ((this.BucketType & CooldownBucketType.Channel) != 0)
channelId = ctx.Channel.Id;
if ((this.BucketType & CooldownBucketType.Guild) != 0 && ctx.Guild == null)
channelId = ctx.Channel.Id;

guildId = 0ul;
if (ctx.Guild != null && (this.BucketType & CooldownBucketType.Guild) != 0)
if (ctx.Guild is not null && (this.BucketType & CooldownBucketType.Guild) != 0)
guildId = ctx.Guild.Id;

var bid = CooldownBucket.MakeId(userId, channelId, guildId);
memberId = 0ul;
if (ctx.Guild is not null && ctx.Member is not null && (this.BucketType & CooldownBucketType.Member) != 0)
memberId = ctx.Member.Id;

var bid = CooldownBucket.MakeId(ctx.FullCommandName, ctx.Interaction.Data.Id.ToString(CultureInfo.InvariantCulture), userId, channelId, guildId, memberId);
return bid;
}

Expand All @@ -109,29 +111,36 @@ private string GetBucketId(BaseContext ctx, out ulong userId, out ulong channelI
/// <param name="ctx">The command context.</param>
public override async Task<bool> ExecuteChecksAsync(BaseContext ctx)
{
var bid = this.GetBucketId(ctx, out var usr, out var chn, out var gld);
if (!this.Buckets.TryGetValue(bid, out var bucket))
{
bucket = new(this.MaxUses, this.Reset, usr, chn, gld);
this.Buckets.AddOrUpdate(bid, bucket, (k, v) => bucket);
}
var bid = this.GetBucketId(ctx, out var usr, out var chn, out var gld, out var mem);
if (ctx.Client.CommandCooldownBuckets.TryGetValue(bid, out var bucket))
return await this.RespondRatelimitHitAsync(ctx, await bucket.DecrementUseAsync(ctx), bucket);

bucket = new(this.MaxUses, this.Reset, ctx.FullCommandName, ctx.Interaction.Data.Id.ToString(CultureInfo.InvariantCulture), usr, chn, gld, mem);
ctx.Client.CommandCooldownBuckets.AddOrUpdate(bid, bucket, (k, v) => bucket);

return await bucket.DecrementUseAsync().ConfigureAwait(false);
return await this.RespondRatelimitHitAsync(ctx, await bucket.DecrementUseAsync(ctx), bucket);
}
}

/// <summary>
/// Represents a cooldown bucket for commands.
/// </summary>
public sealed class ContextMenuCooldownBucket : CooldownBucket
{
internal ContextMenuCooldownBucket(int maxUses, TimeSpan resetAfter, ulong userId = 0, ulong channelId = 0, ulong guildId = 0)
: base(maxUses, resetAfter, userId, channelId, guildId)
{ }
/// <inheritdoc/>
public async Task<bool> RespondRatelimitHitAsync(BaseContext ctx, bool noHit, CooldownBucket bucket)
{
if (noHit)
return true;

/// <summary>
/// Returns a string representation of this command cooldown bucket.
/// </summary>
/// <returns>String representation of this command cooldown bucket.</returns>
public override string ToString() => $"Context Menu Command bucket {this.BucketId}";
if (this.ResponderType is null)
{
if (ApplicationCommandsExtension.Configuration.AutoDefer)
await ctx.EditResponseAsync(new DiscordWebhookBuilder().WithContent($"Error: Ratelimit hit\nTry again {bucket.ResetsAt.Timestamp()}"));
else
await ctx.CreateResponseAsync(InteractionResponseType.ChannelMessageWithSource, new DiscordInteractionResponseBuilder().WithContent($"Error: Ratelimit hit\nTry again {bucket.ResetsAt.Timestamp()}").AsEphemeral());

return false;
}

var providerMethod = this.ResponderType.GetMethod(nameof(ICooldownResponder.Responder));
var providerInstance = Activator.CreateInstance(this.ResponderType);
await ((Task)providerMethod.Invoke(providerInstance, [ctx])).ConfigureAwait(false);

return false;
}
}
Loading
Loading