ari

discord music bot
git clone git://git.bain.cz/ari.git
Log | Files | Refs

commit 0354f21ea881740c0c3b3d23bffd522907e8512e
Author: bain <bain@bain.cz>
Date:   Sun, 21 Aug 2022 23:12:25 +0200

Initial commit

Diffstat:
A .gitignore | 6 ++++++
A ari/__main__.py | 4 ++++
A ari/bot.py | 79 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
A ari/cache.py | 158 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
A ari/constants.py | 17 +++++++++++++++++
A ari/messages.py | 167 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
A ari/player.py | 131 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
A ari/queue.py | 71 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
A images/emoji_play.png | 0
A images/emoji_pushtofront.png | 0
A images/emoji_warning.png | 0
A requirements.txt | 4 ++++
12 files changed, 637 insertions(+), 0 deletions(-)

diff --git a/.gitignore b/.gitignore @@ -0,0 +1,6 @@ +/venv/ +/pyrightconfig.json +/.env +/.vimspector.json +*__pycache__* +/cache/ diff --git a/ari/__main__.py b/ari/__main__.py @@ -0,0 +1,4 @@ +from .bot import bot +from .constants import BOT_TOKEN + +bot.run(BOT_TOKEN, log_handler=None) diff --git a/ari/bot.py b/ari/bot.py @@ -0,0 +1,79 @@ +import logging +from typing import Dict + +import discord +from discord import app_commands + +from .player import Player + +from .constants import BAINS_POSIT_NOTE, MUSIC_CACHE, Emoji +from .messages import MessageHandler +from .cache import Cache + +discord.utils.setup_logging() +logging.getLogger("ari").setLevel(logging.DEBUG) + +logger = logging.getLogger(__name__) + + +intents = discord.Intents.default() +intents.message_content = True + + +class Bot(discord.Client): + def __init__(self, intents: discord.Intents) -> None: + super().__init__(intents=intents) + self.tree = app_commands.CommandTree(self) + self.message_handler = MessageHandler(self) + self.music_cache = Cache(self.http, MUSIC_CACHE, max_size=30000000) + self.players: Dict[int, Player] = {} + + async def setup_hook(self) -> None: + return + await self.tree.sync() + + async def on_message(self, message: discord.Message): + await self.message_handler.handle_message(message) + + async def on_raw_reaction_add(self, payload: discord.RawReactionActionEvent): + await self.message_handler.handle_reaction_add(payload) + + +bot = Bot(intents) + + +@bot.tree.command(name="skip", description="Skip specified number of songs") +@app_commands.describe(number="Number of songs to skip") +@app_commands.guild_only() +async def skip(interaction: discord.Interaction, number: int = 1): + assert interaction.guild is not None + if number < 1: + await interaction.response.send_message( + f"{Emoji.error} The number of songs to skip must be bigger or equal to 1", + ephemeral=True, + ) + return + player = bot.players.get(interaction.guild.id) + if player is None or not 
player.is_running(): + await interaction.response.send_message( + f"{Emoji.error} I am not currently playing anything", ephemeral=True + ) + else: + player.skip(number) + await interaction.response.send_message( + f"Skipping {number} song{'s' if number > 1 else ''}..." + ) + + +@bot.tree.command(name="stop", description="Stop playing and disconnect") +@app_commands.guild_only() +async def stop(interaction: discord.Interaction): + assert interaction.guild is not None + player = bot.players.get(interaction.guild.id) + if player is None or not player.is_running(): + await interaction.response.send_message( + f"{Emoji.error} I am not currently playing anything", ephemeral=True + ) + else: + player.stop() + await interaction.response.send_message(f"Disconnecting...") diff --git a/ari/cache.py b/ari/cache.py @@ -0,0 +1,158 @@ +import asyncio +import hashlib +import time +from typing import Dict, Tuple, List + +import aiohttp +import yt_dlp as youtube_dl +import os +import logging + +logger = logging.getLogger(__name__) + + +class CacheError(Exception): + """Error while making sure a file exists""" + + pass + + +def link_hash(link: str) -> str: + return hashlib.md5(link.encode()).hexdigest() + + +def extract_info_yt(link): + with youtube_dl.YoutubeDL({"format": "bestaudio"}) as ydl: + return ydl.extract_info(link, download=False) + + +async def get_size(link: str): + try: + if link.startswith("https://youtu.be/"): + return (await asyncio.to_thread(extract_info_yt, link)).get( # type: ignore + "filesize", -1 + ) + else: + async with aiohttp.ClientSession() as session: + async with session.head(link, allow_redirects=True) as response: + return response.content_length or -1 + except youtube_dl.DownloadError: + logger.info("failed to fetch video size") + return -1 + + +class Cache: + def __init__(self, http, directory: str, max_size: int = 500 * 1000**2): + self._http = http + self._directory: str = directory + self._max_size: int = max_size + self._current_size: int = 0 + 
self._links: Dict[str, Tuple[int, int, str]] = {} + self.locked_files: List[str] = [] + self.cache_lock = asyncio.Lock() + + os.makedirs(self._directory, exist_ok=True) + + async def _try_free_space(self, size: int) -> bool: + """NEEDS TO BE GUARDED WITH CACHE LOCK""" + logger.debug(f"freeing {size} bytes of space") + if size > self._max_size: + return False + + to_remove = [] + successful = False + for link in sorted(self._links, key=lambda x: x[0]): + if self._links[link][2] not in self.locked_files: + os.remove(f"{self._directory}/{self._links[link][2]}") + self._current_size -= self._links[link][1] + to_remove.append(link) + if size + self._current_size <= self._max_size: + successful = True + break + if to_remove: + logger.info(f"removed {len(to_remove)} files") + for link in to_remove: + del self._links[link] + + return successful + + async def _download_youtube_file(self, link: str): + success = False + with youtube_dl.YoutubeDL( + { + "outtmpl": f"{self._directory}/{link_hash(link)}", + "format": "bestaudio", + "updatetime": False, + "ratelimit": 5000000, + } + ) as ydl: + for _ in range(3): + ex = asyncio.to_thread(ydl.download, (link,)) + try: + await asyncio.wait_for(ex, 10) + except ( + asyncio.TimeoutError, + youtube_dl.utils.DownloadError, + ): + pass + else: + success = True + break + finally: + # clean up potential leftovers from ytdl + if os.path.exists(f"{self._directory}/{link_hash(link)}.part"): + os.remove(f"{self._directory}/{link_hash(link)}.part") + if not success: + raise CacheError("Youtube download failed") + + async def _download_file(self, link: str): + """NEEDS TO BE GUARDED WITH CACHE LOCK""" + size = await get_size(link) + if size < 0: + raise CacheError("Music file size unknown") + + logger.info(f"Downloading video of size {size / 1000:.2f}kb") + if (size < self._max_size - self._current_size) or ( + await self._try_free_space(size) + ): + logger.debug(f"size: {self._max_size - self._current_size}") + if 
link.startswith("https://youtu.be/"): + await self._download_youtube_file(link) + else: + try: + data = await asyncio.wait_for(self._http.get_from_cdn(link), 10) + except asyncio.TimeoutError: + raise CacheError("Discord download is taking too long") + with open(f"{self._directory}/{link_hash(link)}", "wb+") as f: + f.write(data) + self._current_size += size # reserve space + + self._links[link] = (time.time_ns(), size, link_hash(link)) + return True + + raise CacheError("Music file size too large or unknown") + + async def ensure_existence(self, link: str): + async with self.cache_lock: + fp = f"{self._directory}/{link_hash(link)}" + if self._links.get(link): + return CacheContextManager(self, fp) + else: + await self._download_file(link) + return CacheContextManager(self, fp) + + +class CacheContextManager: + def __init__(self, cache: Cache, file: str): + self._cache: Cache = cache + self._file: str = file + + def __enter__(self) -> str: + self._cache.locked_files.append(self._file) + if not os.path.exists(self._file): + self._cache.locked_files.remove(self._file) + raise CacheError("cannot lock file; it no longer exists") + return self._file + + def __exit__(self, exc_type, exc_val, exc_tb): + self._cache.locked_files.remove(self._file) diff --git a/ari/constants.py b/ari/constants.py @@ -0,0 +1,17 @@ +import os +from typing import NamedTuple +import discord + +BAINS_POSIT_NOTE = discord.Object(id=630144683359862814) + + +class Emoji(NamedTuple): + play = "<:play:1010298926550749247>" + skip_to = "<:push_to_front:1010299358174007396>" + download_error = "<:download_error:1010299866246807662>" + error = "<:error:755487487807324230>" + + +PRELOAD = bool(int(os.getenv("PRELOAD", "1"))) +MUSIC_CACHE = os.getenv("MUSIC_CACHE", "cache") +BOT_TOKEN = os.getenv("BOT_TOKEN", "invalid") diff --git a/ari/messages.py b/ari/messages.py @@ -0,0 +1,167 @@ +import time +import asyncio +import logging +import re +from typing import TYPE_CHECKING, Dict, Set +from bidict import 
bidict, MutableBidirectionalMapping + +from .constants import Emoji +from .queue import Content, QueueItem +from .player import Player + +if TYPE_CHECKING: + from typing import Dict, Set + from .bot import Bot + +_youtube_regex = re.compile( + r"http(?:s?):\/\/(?:www\.)?youtu(?:be\.com\/watch\?v=|\.be\/)([\w\-\_]*)(&(amp;)?[\w\?=]*)?" +) + +import discord + +logger = logging.getLogger(__name__) + + +class MessageHandler: + def __init__(self, client) -> None: + self.client: "Bot" = client + self.skippables: MutableBidirectionalMapping[int, QueueItem] = bidict() + self.requests = { + Emoji.play: self.handle_play_request, + Emoji.download_error: self.show_errors, + } + + async def handle_message(self, message: discord.Message) -> None: + if message.guild is None: + return + if _youtube_regex.search(message.content) is not None or any( + [ + "audio" in a.content_type + for a in message.attachments + if a.content_type is not None + ] + ): + logger.debug("message %s has a youtuble link", message.id) + await message.add_reaction(Emoji.play) + + async def handle_reaction_add( + self, payload: discord.RawReactionActionEvent + ) -> None: + assert self.client.user is not None + if payload.guild_id is None: + return # the message must be in a guild + if payload.user_id == self.client.user.id: + return # ignore self + + req = self.requests.get(str(payload.emoji)) + if req is not None: + await req(payload) + + async def get_or_fetch_message( + self, message_id: int, channel_id: int + ) -> discord.Message: + message = next( + filter(lambda m: m.id == message_id, self.client.cached_messages), + None, + ) + if message is None: + logger.debug("message %s was not cached, fetching...", message_id) + channel = self.client.get_channel( + channel_id + ) or await self.client.fetch_channel(channel_id) + assert isinstance(channel, discord.TextChannel) + message = await channel.fetch_message(message_id) + + # manually add message to the client's message cache. 
+ # this way the client can refresh the message when it is edited + cache = self.client._connection._messages + if cache is not None: + cache.append(message) + return message + + async def handle_play_request(self, payload: discord.RawReactionActionEvent): + logger.info("play request on message %s", payload.message_id) + + # check cache + message = await self.get_or_fetch_message( + payload.message_id, payload.channel_id + ) + + assert message.guild is not None + user = message.guild.get_member(payload.user_id) + if user is None: + user = await message.guild.fetch_member(payload.user_id) + + # get all videos and attachments from the message + playable = [] + for attachment in message.attachments: + if attachment.content_type in ("audio/mpeg", "audio/ogg", "audio/wave"): + playable.append(attachment.url) + pos = 0 + while pos < len(message.content): + match = _youtube_regex.search(message.content, pos) + if match is None: + break + pos = match.span()[1] + playable.append("https://youtu.be/" + match.groups()[0]) + + logger.debug("adding %s songs to queue", len(playable)) + if not playable: + await user.send( + f"{Emoji.error} There are no playable videos/music in the message" + ) + await message.clear_reaction(Emoji.play) + return + + if ( + next( + filter( + lambda x: str(x.emoji) == Emoji.download_error, message.reactions + ), + None, + ) + is not None + ): + # clear any hanging errors on the message + # caused only when the message is played again before + # the reaction is automatically removed after a timeout + logger.debug("clearing download error reaction") + await message.clear_reaction(Emoji.download_error) + await message.remove_reaction(Emoji.play, discord.Object(id=payload.user_id)) + + # push video ids to the player queue + player = self.client.players.get(user.guild.id) + if player is None or not player.is_running(): + player = await Player.create(self.client, user) + if player is None: + await user.send( + f"{Emoji.error} Failed to connect to voice. 
Are you in a voice channel?", + ) + return + self.client.players[user.guild.id] = player + for id in playable: + player.queue.push(Content(message, id)) + + if not player.is_running(): + asyncio.create_task(player.run()) + + async def show_errors(self, payload: discord.RawReactionActionEvent): + assert payload.guild_id is not None + dm = await self.client.create_dm(discord.Object(id=payload.user_id)) + player = self.client.players.get(payload.guild_id) + if player is None or not player.is_running(): + # error emojis should be only visible when a player is running + # this was probably fired in the split second when the player was + # turning off + return + message = f"{Emoji.download_error} ***Sorry, I was not able to play the following songs:***\n```\n" + for error in player.errored_songs: + message += f" - {error}\n" + message += "```" + await dm.send(message) + message = await self.get_or_fetch_message( + payload.message_id, payload.channel_id + ) + await message.remove_reaction( + Emoji.download_error, discord.Object(id=payload.user_id) + ) diff --git a/ari/player.py b/ari/player.py @@ -0,0 +1,131 @@ +from yt_dlp import os +from .cache import CacheError +from .constants import Emoji, PRELOAD +from .queue import Queue, QueueItem +import discord +import logging +import asyncio + +from typing import TYPE_CHECKING, Set + +if TYPE_CHECKING: + from .bot import Bot +logger = logging.getLogger(__name__) + + +class Player: + def __init__(self, client: "Bot", voice_client: discord.VoiceClient) -> None: + self.queue = Queue() + self.voice = voice_client + self.client = client + self.errored: Set[discord.Message] = set() + self.errored_songs: Set[str] = set() + self._running = False + self._skip = 0 + + @classmethod + async def create(cls, client: "Bot", user: discord.Member): + if not user.voice or not user.voice.channel: + return None + + logger.debug("creating player") + voice_client = user.guild.voice_client + assert voice_client is None or isinstance(voice_client, 
discord.VoiceClient) + if not voice_client: + voice_client = await user.voice.channel.connect() # type: ignore + else: + if not voice_client.is_connected(): + try: + voice_client.stop() + logger.debug("reusing existing voice client") + await voice_client.connect(reconnect=True, timeout=10) + except asyncio.TimeoutError or discord.ConnectionClosed: + logger.warning(f"failed to connect") + return None + + return cls(client, voice_client) + + async def run(self): + logger.info(f"{hash(self)} running player") + self._running = True + while not self.queue.empty() and self.voice.is_connected() and self._running: + music = self.queue.pop() + if self._skip > 0: + logger.debug(f"skips remaining {self._skip-1}") + self._skip -= 1 + continue + assert music is not None + logger.debug("playing %s", music) + try: + with await self.client.music_cache.ensure_existence( + music.content.video_id + ) as file: + # ensuring existence can take a long time + if self.voice.is_connected(): + self.voice.play(discord.FFmpegPCMAudio(file)) + + tried_preload = False + while ( + self.voice.is_connected() + and self.voice.is_playing() + and self._running + ): + if PRELOAD and not tried_preload: + tried_preload = await self.preload() + if self._skip > 0: + logger.debug("skipping currently playing") + self.voice.stop() + self._skip -= 1 + break + await asyncio.sleep(1) + except CacheError: + await self.add_error(music) + self.errored.add(music.content.message) + self.errored_songs.add(music.content.video_id) + except Exception: + logger.exception(f"{hash(self)}: exception while playing") + break + self._running = False + await self.voice.disconnect() + await self.cleanup_errors() + logger.info(f"{hash(self)} player shutdown") + + async def cleanup_errors(self): + for message in self.errored: + try: + await message.clear_reaction(Emoji.download_error) + except discord.HTTPException: + # the message could already be gone + pass + + async def add_error(self, item: QueueItem): + try: + await 
item.content.message.add_reaction(Emoji.download_error) + except discord.HTTPException as e: + # the message could be already gone + logger.warning( + "could not add error to message %s: %s", item.content.message.id, e + ) + + def is_running(self): + return self._running + + def skip(self, num: int): + self._skip += num + + def stop(self): + self._running = False + + async def preload(self): + preload = self.queue.peek() + if preload is not None: + # we ignore the context manager, thus we're not actually + # locking the file, just preloading it + try: + await self.client.music_cache.ensure_existence(preload.content.video_id) + return True + except CacheError: + # error silently, maybe we can free up space by letting + # go of the currently playing song + pass + return False diff --git a/ari/queue.py b/ari/queue.py @@ -0,0 +1,71 @@ +from collections import deque +from dataclasses import dataclass +from threading import Lock +from typing import Deque, NamedTuple, Optional +import logging + +import discord + +from .constants import Emoji + +logger = logging.getLogger(__name__) + + +class LimitReached(Exception): + pass + + +class Content(NamedTuple): + message: discord.Message + video_id: str + + +@dataclass(eq=False) +class QueueItem: + content: Content + next: Optional["QueueItem"] + queue: Optional[int] + + +class Queue: + def __init__(self, max_length: int = 50) -> None: + self._lock = Lock() + self._front = None + self._back = None + self.length = 0 + + def push(self, item: Content) -> QueueItem: + with self._lock: + if self._front is None: + self._front = QueueItem(content=item, queue=id(self), next=None) + self._back = self._front + else: + assert self._back is not None + self._back.next = QueueItem(content=item, queue=id(self), next=None) + self._back = self._back.next + self.length += 1 + return self._back + + def pop(self) -> Optional[QueueItem]: + with self._lock: + item = self._front + if item is not None: + self._front = item.next + item.queue = None # 
item is no longer a part of the queue + self.length -= 1 + return item + + def skip_to(self, item: QueueItem): + with self._lock: + if item.queue != id(self): + raise ValueError("item must be from this queue") + self._front = item + + def peek(self) -> Optional[QueueItem]: + return self._front + + def __contains__(self, item: QueueItem) -> bool: + return item.queue == id(self) + + def empty(self) -> bool: + return self.length == 0 diff --git a/images/emoji_play.png b/images/emoji_play.png Binary files differ. diff --git a/images/emoji_pushtofront.png b/images/emoji_pushtofront.png Binary files differ. diff --git a/images/emoji_warning.png b/images/emoji_warning.png Binary files differ. diff --git a/requirements.txt b/requirements.txt @@ -0,0 +1,4 @@ +discord.py[voice] +bidict +yt_dlp +