diff --git a/musicbot/audiocontroller.py b/musicbot/audiocontroller.py index 65963402..ce9a8ad4 100644 --- a/musicbot/audiocontroller.py +++ b/musicbot/audiocontroller.py @@ -1,4 +1,5 @@ from enum import Enum +from inspect import isawaitable from typing import TYPE_CHECKING, Coroutine, Optional, List, Tuple import discord @@ -8,6 +9,7 @@ from musicbot import linkutils, utils from musicbot.playlist import Playlist from musicbot.songinfo import Song +from musicbot.utils import compare_components # avoiding circular import if TYPE_CHECKING: @@ -15,6 +17,7 @@ _cached_downloaders: List[Tuple[dict, yt_dlp.YoutubeDL]] = [] +_not_provided = object() class PauseState(Enum): @@ -23,6 +26,24 @@ class PauseState(Enum): RESUMED = "Resumed playback :arrow_forward:" +class LoopState(Enum): + INVALID = "Invalid loop mode!" + ENABLED = "Loop enabled :arrows_counterclockwise:" + DISABLED = "Loop disabled :x:" + + +class MusicButton(discord.ui.Button): + def __init__(self, callback, **kwargs): + super().__init__(**kwargs) + self._callback = callback + + async def callback(self, inter): + await inter.response.defer() + res = self._callback(inter) + if isawaitable(res): + await res + + class AudioController(object): """Controls the playback of audio and the sequential playing of the songs. @@ -37,6 +58,7 @@ def __init__(self, bot: "MusicBot", guild: discord.Guild): self.bot = bot self.playlist = Playlist() self.current_song = None + self._next_song = None self.guild = guild sett = bot.settings[guild] @@ -46,6 +68,9 @@ def __init__(self, bot: "MusicBot", guild: discord.Guild): self.command_channel: Optional[discord.abc.Messageable] = None + self.last_message = None + self.last_view = None + # according to Python documentation, we need # to keep strong references to all tasks self._tasks = set() @@ -102,6 +127,107 @@ async def fetch_song_info(self, song: Song): ) song.update(info) + def make_view(self): + if not self.is_active(): + return None + + view = discord.ui.View(timeout=None) + is_empty = len(self.playlist) == 0 + + prev_button = MusicButton( + lambda _: self.prev_song(), + disabled=not self.playlist.has_prev(), + emoji="⏮️", + ) + view.add_item(prev_button) + + pause_button = MusicButton( + lambda _: self.pause(), + emoji="⏸️" if self.guild.voice_client.is_playing() else "▶️", + ) + view.add_item(pause_button) + + next_button = MusicButton( + lambda _: self.next_song(), + disabled=not self.playlist.has_next(), + emoji="⏭️", + ) + view.add_item(next_button) + + loop_button = MusicButton( + lambda _: self.loop(), + disabled=is_empty, + emoji="🔁", + label="Loop: " + self.playlist.loop, + ) + view.add_item(loop_button) + + np_button = MusicButton( + self.current_song_callback, + row=1, + disabled=self.current_song is None, + emoji="💿", + ) + view.add_item(np_button) + + shuffle_button = MusicButton( + lambda _: self.playlist.shuffle(), + row=1, + disabled=is_empty, + emoji="🔀", + ) + view.add_item(shuffle_button) + + queue_button = MusicButton( + self.queue_callback, row=1, disabled=is_empty, emoji="📜" + ) + view.add_item(queue_button) + + stop_button = MusicButton( + lambda _: self.stop_player(), + row=1, + emoji="⏹️", + style=discord.ButtonStyle.red, + ) + view.add_item(stop_button) + + self.last_view = view + + return view + + async def current_song_callback(self, inter): + await (await inter.client.get_application_context(inter)).send( + embed=self.current_song.info.format_output(config.SONGINFO_SONGINFO), + ) + + async def queue_callback(self, inter): + await (await inter.client.get_application_context(inter)).send( + embed=self.playlist.queue_embed(), + ) + + async def update_view(self, view=_not_provided): + msg = self.last_message + if not msg: + return + old_view = self.last_view + if view is _not_provided: + view = self.make_view() + if view is None: + self.last_message = None + elif compare_components(old_view.to_components(), view.to_components()): + return + try: + await msg.edit(view=view) + except discord.HTTPException as e: + if e.code == 50027: # Invalid Webhook Token + try: + self.last_message = await msg.channel.fetch_message(msg.id) + await self.update_view(view) + except discord.NotFound: + self.last_message = None + else: + print("Failed to update view:", e) + def is_active(self) -> bool: client = self.guild.voice_client return client is not None and (client.is_playing() or client.is_paused()) @@ -121,14 +247,36 @@ def pause(self): elif client.is_paused(): client.resume() return PauseState.RESUMED - else: - return PauseState.NOTHING_TO_PAUSE return PauseState.NOTHING_TO_PAUSE + def loop(self, mode=None): + if mode is None: + if self.playlist.loop == "off": + mode = "all" + else: + mode = "off" + + if mode not in ("all", "single", "off"): + return LoopState.INVALID + + self.playlist.loop = mode + + if mode == "off": + return LoopState.DISABLED + return LoopState.ENABLED + def next_song(self, error=None): """Invoked after a song is finished. Plays the next song if there is one.""" - next_song = self.playlist.next(self.current_song) + if self.is_active(): + self.guild.voice_client.stop() + return + + if self._next_song: + next_song = self._next_song + self._next_song = None + else: + next_song = self.playlist.next() self.current_song = None @@ -158,8 +306,6 @@ async def play_song(self, song: Song): self.playlist.add_name(song.info.title) self.current_song = song - self.playlist.playhistory.append(self.current_song) - self.guild.voice_client.play( discord.FFmpegPCMAudio( song.base_url, @@ -178,8 +324,6 @@ async def play_song(self, song: Song): embed=song.info.format_output(config.SONGINFO_NOW_PLAYING) ) - self.playlist.playque.popleft() - for song in list(self.playlist.playque)[: config.MAX_SONG_PRELOAD]: self.add_task(self.preload(song)) @@ -334,34 +478,30 @@ async def search_youtube(self, title: str) -> Optional[dict]: return r["entries"][0] - async def stop_player(self): + def stop_player(self): """Stops the player and removes all songs from the queue""" if not self.is_active(): return self.playlist.loop = "off" - self.playlist.next(self.current_song) + self.playlist.next() self.clear_queue() self.guild.voice_client.stop() - async def prev_song(self) -> bool: + def prev_song(self) -> bool: """Loads the last song from the history into the queue and starts it""" self.timer.cancel() self.timer = utils.Timer(self.timeout_handler) - if len(self.playlist.playhistory) == 0: + prev_song = self.playlist.prev() + if not prev_song: return False - prev_song = self.playlist.prev(self.current_song) - if not self.is_active(): - - if prev_song == "Dummy": - self.playlist.next(self.current_song) - return False - await self.play_song(prev_song) + self.add_task(self.play_song(prev_song)) else: + self._next_song = prev_song self.guild.voice_client.stop() return True @@ -396,7 +536,8 @@ async def uconnect(self, ctx): return False async def udisconnect(self): - await self.stop_player() + self.stop_player() + await self.update_view(None) if self.guild.voice_client is None: return False await self.guild.voice_client.disconnect(force=True) diff --git a/musicbot/bot.py b/musicbot/bot.py index 4a73ba07..a569bbb5 100644 --- a/musicbot/bot.py +++ b/musicbot/bot.py @@ -1,7 +1,7 @@ from typing import Dict, Union import discord -from discord.ext import bridge +from discord.ext import bridge, tasks from discord.ext.commands import DefaultHelpCommand from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine from sqlalchemy.orm import sessionmaker @@ -38,6 +38,11 @@ async def start(self, *args, **kwargs): await extract_legacy_settings(self) return await super().start(*args, **kwargs) + async def close(self): + for audiocontroller in self.audio_controllers.values(): + await audiocontroller.udisconnect() + return await super().close() + async def on_ready(self): self.settings.update(await GuildSettings.load_many(self, self.guilds)) @@ -47,10 +52,18 @@ async def on_ready(self): print(config.STARTUP_COMPLETE_MESSAGE) + if not self.update_views.is_running(): + self.update_views.start() + async def on_guild_join(self, guild): print(guild.name) await self.register(guild) + @tasks.loop(seconds=1) + async def update_views(self): + for audiocontroller in self.audio_controllers.values(): + await audiocontroller.update_view() + def add_command(self, command): # fix empty description # https://github.com/Pycord-Development/pycord/issues/1619 @@ -74,6 +87,13 @@ async def get_prefix( async def get_application_context(self, interaction): return await super().get_application_context(interaction, ApplicationContext) + async def process_application_commands(self, inter): + if not inter.guild: + await inter.response.send_message(config.NO_GUILD_MESSAGE) + return + + await super().process_application_commands(inter) + async def process_commands(self, message: discord.Message): if message.author.bot: return @@ -124,8 +144,18 @@ class Context(bridge.BridgeContext): guild: discord.Guild async def send(self, *args, **kwargs): + audiocontroller = self.bot.audio_controllers[self.guild] + await audiocontroller.update_view(None) + view = audiocontroller.make_view() + if view: + kwargs["view"] = view # use `respond` for compatibility - return await self.respond(*args, **kwargs) + res = await self.respond(*args, **kwargs) + if isinstance(res, discord.Interaction): + audiocontroller.last_message = await res.original_message() + else: + audiocontroller.last_message = res + return res class ExtContext(bridge.BridgeExtContext, Context): diff --git a/musicbot/commands/music.py b/musicbot/commands/music.py index 35563aee..7035662a 100644 --- a/musicbot/commands/music.py +++ b/musicbot/commands/music.py @@ -4,6 +4,7 @@ from config import config from musicbot import linkutils, utils from musicbot.bot import MusicBot, Context +from musicbot.playlist import PlaylistError class Music(commands.Cog): @@ -75,22 +76,8 @@ async def _loop(self, ctx: Context, mode=None): await ctx.send("No songs in queue!") return - if mode is None: - if audiocontroller.playlist.loop == "off": - mode = "all" - else: - mode = "off" - - if mode not in ("all", "single", "off"): - await ctx.send("Invalid loop mode!") - return - - audiocontroller.playlist.loop = mode - - if mode in ("all", "single"): - await ctx.send("Loop enabled :arrows_counterclockwise:") - else: - await ctx.send("Loop disabled :x:") + result = audiocontroller.loop(mode) + await ctx.send(result.value) @bridge.bridge_command( name="shuffle", @@ -148,23 +135,7 @@ async def _queue(self, ctx: Context): if config.MAX_SONG_PRELOAD > 25: config.MAX_SONG_PRELOAD = 25 - embed = discord.Embed( - title=":scroll: Queue [{}]".format(len(playlist.playque)), - color=config.EMBED_COLOR, - ) - - for counter, song in enumerate( - list(playlist.playque)[: config.MAX_SONG_PRELOAD], start=1 - ): - embed.add_field( - name="{}.".format(str(counter)), - value="[{}]({})".format( - song.info.title or song.info.webpage_url, song.info.webpage_url - ), - inline=False, - ) - - await ctx.send(embed=embed) + await ctx.send(embed=playlist.queue_embed()) @bridge.bridge_command( name="stop", @@ -177,8 +148,7 @@ async def _stop(self, ctx: Context): return audiocontroller = ctx.bot.audio_controllers[ctx.guild] - audiocontroller.playlist.loop = "off" - await audiocontroller.stop_player() + audiocontroller.stop_player() await ctx.send("Stopped all sessions :octagonal_sign:") @bridge.bridge_command( @@ -194,10 +164,9 @@ async def _move(self, ctx: Context, src_pos: int, dest_pos: int): return try: audiocontroller.playlist.move(src_pos - 1, dest_pos - 1) - except IndexError: - await ctx.send("Wrong position") - return - await ctx.send("Moved") + await ctx.send("Moved ↔️") + except PlaylistError as e: + await ctx.send(e) @bridge.bridge_command( name="skip", @@ -218,7 +187,7 @@ async def _skip(self, ctx: Context): if not audiocontroller.is_active(): await ctx.send(config.QUEUE_EMPTY) return - ctx.guild.voice_client.stop() + audiocontroller.next_song() await ctx.send("Skipped current song :fast_forward:") @bridge.bridge_command( @@ -253,7 +222,7 @@ async def _prev(self, ctx: Context): audiocontroller.timer.cancel() audiocontroller.timer = utils.Timer(audiocontroller.timeout_handler) - if await audiocontroller.prev_song(): + if audiocontroller.prev_song(): await ctx.send("Playing previous song :track_previous:") else: await ctx.send("No previous track.") diff --git a/musicbot/playlist.py b/musicbot/playlist.py index 2c2ea42a..13649b04 100644 --- a/musicbot/playlist.py +++ b/musicbot/playlist.py @@ -2,10 +2,16 @@ from typing import Optional from collections import deque +from discord import Embed + from config import config from musicbot.songinfo import Song +class PlaylistError(Exception): + pass + + class Playlist: """Stores the youtube links of songs to be played and already played and offers basic operation on the queues""" @@ -30,43 +36,85 @@ def add_name(self, trackname: str): def add(self, track: Song): self.playque.append(track) - def next(self, song_played: Optional[Song]) -> Optional[Song]: + def has_next(self) -> bool: + return len(self.playque) >= (2 if self.loop == "off" else 1) - if self.loop == "single": - self.playque.appendleft(self.playhistory[-1]) - elif self.loop == "all": - self.playque.append(self.playhistory[-1]) + def has_prev(self) -> bool: + return len(self.playhistory if self.loop == "off" else self.playque) != 0 + def next(self) -> Optional[Song]: if len(self.playque) == 0: return None - if song_played != "Dummy": + if self.loop == "off": + self.playhistory.append(self.playque.popleft()) if len(self.playhistory) > config.MAX_HISTORY_LENGTH: self.playhistory.popleft() + if len(self.playque) != 0: + return self.playque[0] + else: + return None + + if self.loop == "all": + self.playque.rotate(-1) return self.playque[0] - def prev(self, current_song: Optional[Song]) -> Song: + def prev(self) -> Optional[Song]: + if self.loop == "off": + if len(self.playhistory) != 0: + song = self.playhistory.pop() + self.playque.appendleft(song) + return song + else: + return None - if current_song is None: - self.playque.appendleft(self.playhistory[-1]) - return self.playque[0] + if len(self.playque) == 0: + return None + + if self.loop == "all": + self.playque.rotate() - ind = self.playhistory.index(current_song) - prev = self.playhistory[ind - 1] - self.playque.appendleft(prev) - if current_song is not None: - self.playque.insert(1, current_song) - return prev + return self.playque[0] def shuffle(self): + first = self.playque.popleft() random.shuffle(self.playque) + self.playque.appendleft(first) def move(self, oldindex: int, newindex: int): - temp = self.playque[oldindex] + if oldindex < 0 or newindex < 0: + raise PlaylistError("Negative indexes are not supported.") + if oldindex == 0 or newindex == 0: + raise PlaylistError( + "Cannot move the first song since it's already playing." + ) + try: + temp = self.playque[oldindex] + except IndexError as e: + raise PlaylistError("Invalid position.") from e del self.playque[oldindex] self.playque.insert(newindex, temp) def empty(self): self.playque.clear() self.playhistory.clear() + + def queue_embed(self) -> Embed: + embed = Embed( + title=":scroll: Queue [{}]".format(len(self.playque)), + color=config.EMBED_COLOR, + ) + + for counter, song in enumerate( + list(self.playque)[: config.MAX_SONG_PRELOAD], start=1 + ): + embed.add_field( + name="{}.".format(str(counter)), + value="[{}]({})".format( + song.info.title or song.info.webpage_url, song.info.webpage_url + ), + inline=False, + ) + + return embed diff --git a/musicbot/utils.py b/musicbot/utils.py index ac64d90f..3a847633 100644 --- a/musicbot/utils.py +++ b/musicbot/utils.py @@ -147,6 +147,21 @@ def get_emoji(guild: Guild, string: str) -> Optional[Union[str, Emoji]]: return utils.get(guild.emojis, name=string) +def compare_components(obj1, obj2): + "compare two objects recursively but ignore custom_id in dicts" + if isinstance(obj1, (list, tuple)) and isinstance(obj2, (list, tuple)): + if len(obj1) != len(obj2): + return False + return all(compare_components(x1, x2) for x1, x2 in zip(obj1, obj2)) + elif isinstance(obj1, dict) and isinstance(obj2, dict): + obj1.pop("custom_id", None) + obj2.pop("custom_id", None) + if obj1.keys() != obj2.keys(): + return False + return all(compare_components(obj1[k], obj2[k]) for k in obj1) + return obj1 == obj2 + + class Timer: def __init__(self, callback: Callable[[], Awaitable]): self._callback = callback