# Source code for pylav.storage.models.playlist

from __future__ import annotations

import contextlib
import gzip
import io
import pathlib
import random
import sys
from collections import defaultdict
from collections.abc import AsyncIterator, Iterator
from dataclasses import dataclass

import aiohttp
import brotli  # type: ignore
import discord
import yaml
from dacite import from_dict

from pylav.compat import json
from pylav.constants.config import BROTLI_ENABLED, READ_CACHING_ENABLED
from pylav.constants.playlists import BUNDLED_PLAYLIST_IDS
from pylav.constants.regex import SQUARE_BRACKETS
from pylav.core.context import PyLavContext
from pylav.exceptions.playlist import InvalidPlaylistException
from pylav.helpers.singleton import SingletonCachedByKey
from pylav.logging import getLogger
from pylav.nodes.api.responses.track import Track

# NOTE(review): the module paths of the next five imports were missing from the reviewed copy and
# have been restored from PyLav's package layout — confirm against upstream.
from pylav.storage.database.cache.decodators import maybe_cached
from pylav.storage.database.cache.model import CachedModel
from pylav.storage.database.tables.playlists import PlaylistRow
from pylav.storage.database.tables.tracks import TrackRow
from pylav.type_hints.bot import DISCORD_BOT_TYPE
from pylav.type_hints.dict_typing import JSON_DICT_TYPE

# Module logger; used below for the playlist YAML size debug messages emitted by ``Playlist.to_yaml``.
LOGGER = getLogger("PyLav.Database.Playlist")

    from redbot.core.i18n import Translator  # type: ignore

    _ = Translator("PyLav", pathlib.Path(__file__))
except ImportError:
    Translator = None

    def _(string: str) -> str:
        return string

@dataclass(eq=True, slots=True, unsafe_hash=True, order=True, kw_only=True, frozen=True)
class Playlist(CachedModel, metaclass=SingletonCachedByKey):
    """A database-backed playlist identified by ``id``.

    Reads query ``PlaylistRow`` / ``TrackRow``; several reads are wrapped in ``maybe_cached`` and the
    write methods refresh or invalidate those cached results via ``update_cache`` / ``invalidate_cache``.

    NOTE(review): several identifiers in this class (mostly ``PlaylistRow.id == self.id`` filters and
    ``PlaylistRow.select(...)`` calls) were missing from the reviewed copy and were restored from the
    surrounding code — confirm against upstream.
    """

    id: int

    def get_cache_key(self) -> str:
        """Return this playlist's cache key: its ``id`` rendered as a string."""
        return f"{self.id}"

    @maybe_cached
    async def exists(self) -> bool:
        """Check if the playlist exists.

        Returns
        -------
        bool
            Whether a ``PlaylistRow`` with this playlist's ``id`` exists.
        """
        return await PlaylistRow.exists().where(PlaylistRow.id == self.id)

    @maybe_cached
    async def fetch_all(self) -> JSON_DICT_TYPE:
        """Fetch all stored fields of the playlist from the database.

        Returns
        -------
        dict
            A dict with the keys ``id``, ``name``, ``tracks``, ``scope``, ``author`` and ``url``.
            When no row exists, the column defaults (and an empty track list) are returned instead.
        """
        data = (
            await PlaylistRow.select(
                PlaylistRow.id,
                PlaylistRow.name,
                PlaylistRow.tracks(TrackRow.encoded, TrackRow.info, TrackRow.pluginInfo, load_json=True),
                PlaylistRow.scope,
                PlaylistRow.author,
                PlaylistRow.url,
            )
            .where(PlaylistRow.id == self.id)
            .first()
            .output(load_json=True, nested=True)
        )
        return data or {
            "id": self.id,
            "name": PlaylistRow.name.default,
            "tracks": [],
            "scope": PlaylistRow.scope.default,
            "author": PlaylistRow.author.default,
            "url": PlaylistRow.url.default,
        }

    @maybe_cached
    async def fetch_scope(self) -> int | None:
        """Fetch the scope of the playlist.

        Returns
        -------
        int | None
            The stored scope, or ``PlaylistRow.scope.default`` when no row exists.
        """
        response = (
            await PlaylistRow.select(PlaylistRow.scope)
            .where(PlaylistRow.id == self.id)
            .first()
            .output(load_json=True, nested=True)
        )
        return response["scope"] if response else PlaylistRow.scope.default

    async def update_scope(self, scope: int) -> None:
        """Update the scope of the playlist.

        Inserts the row if missing; on an ``id`` conflict only the ``scope`` column is updated.

        Parameters
        ----------
        scope : int
            The new scope of the playlist.
        """
        await PlaylistRow.insert(PlaylistRow(id=self.id, scope=scope)).on_conflict(
            action="DO UPDATE", target=PlaylistRow.id, values=[PlaylistRow.scope]
        )
        await self.update_cache((self.fetch_scope, scope), (self.exists, True))
        await self.invalidate_cache(self.fetch_all)

    @maybe_cached
    async def fetch_author(self) -> int | None:
        """Fetch the author of the playlist.

        Returns
        -------
        int | None
            The stored author, or ``PlaylistRow.author.default`` when no row exists.
        """
        response = (
            await PlaylistRow.select(PlaylistRow.author)
            .where(PlaylistRow.id == self.id)
            .first()
            .output(load_json=True, nested=True)
        )
        return response["author"] if response else PlaylistRow.author.default

    async def update_author(self, author: int) -> None:
        """Update the author of the playlist.

        Inserts the row if missing; on an ``id`` conflict only the ``author`` column is updated.

        Parameters
        ----------
        author : int
            The new author of the playlist.
        """
        await PlaylistRow.insert(PlaylistRow(id=self.id, author=author)).on_conflict(
            action="DO UPDATE", target=PlaylistRow.id, values=[PlaylistRow.author]
        )
        await self.update_cache((self.fetch_author, author), (self.exists, True))
        await self.invalidate_cache(self.fetch_all)

    @maybe_cached
    async def fetch_name(self) -> str | None:
        """Fetch the name of the playlist.

        Returns
        -------
        str | None
            The stored name, or ``PlaylistRow.name.default`` when no row exists.
        """
        response = (
            await PlaylistRow.select(PlaylistRow.name)
            .where(PlaylistRow.id == self.id)
            .first()
            .output(load_json=True, nested=True)
        )
        return response["name"] if response else PlaylistRow.name.default

    async def update_name(self, name: str) -> None:
        """Update the name of the playlist.

        Inserts the row if missing; on an ``id`` conflict only the ``name`` column is updated.

        Parameters
        ----------
        name : str
            The new name of the playlist.
        """
        await PlaylistRow.insert(PlaylistRow(id=self.id, name=name)).on_conflict(
            action="DO UPDATE", target=PlaylistRow.id, values=[PlaylistRow.name]
        )
        await self.update_cache((self.fetch_name, name), (self.exists, True))
        await self.invalidate_cache(self.fetch_all)

    @maybe_cached
    async def fetch_url(self) -> str | None:
        """Fetch the url of the playlist.

        Returns
        -------
        str | None
            The stored url, or ``PlaylistRow.url.default`` when no row exists.
        """
        response = (
            await PlaylistRow.select(PlaylistRow.url)
            .where(PlaylistRow.id == self.id)
            .first()
            .output(load_json=True, nested=True)
        )
        return response["url"] if response else PlaylistRow.url.default

    async def update_url(self, url: str) -> None:
        """Update the url of the playlist.

        Inserts the row if missing; on an ``id`` conflict only the ``url`` column is updated.

        Parameters
        ----------
        url : str
            The new url of the playlist.
        """
        await PlaylistRow.insert(PlaylistRow(id=self.id, url=url)).on_conflict(
            action="DO UPDATE", target=PlaylistRow.id, values=[PlaylistRow.url]
        )
        await self.update_cache((self.fetch_url, url), (self.exists, True))
        await self.invalidate_cache(self.fetch_all)

    @maybe_cached
    async def fetch_tracks(self) -> list[str | JSON_DICT_TYPE]:
        """Fetch the tracks of the playlist.

        Returns
        -------
        list[str | dict]
            The related tracks as produced by the M2M select of ``encoded``, ``info`` and
            ``pluginInfo``; an empty list when no row exists.
        """
        response = (
            await PlaylistRow.select(
                PlaylistRow.tracks(TrackRow.encoded, TrackRow.info, TrackRow.pluginInfo, load_json=True)
            )
            .where(PlaylistRow.id == self.id)
            .first()
            .output(load_json=True, nested=True)
        )
        return response["tracks"] if response else []

    async def _resolve_track_rows(self, tracks: list[str | JSON_DICT_TYPE | Track]) -> list[TrackRow]:
        """Convert mixed track inputs into ``TrackRow`` rows (created via ``TrackRow.get_or_create``).

        Entries are grouped by kind (encoded ``str``, ``dict`` payload, anything else), in order of
        each kind's first appearance, so all encoded strings are decoded in a single
        ``decode_tracks`` call. Because that call uses ``raise_on_failure=False``, undecodable
        strings presumably produce no row, so the result may be shorter than ``tracks``.
        """
        # TODO: Optimize this once the upstream issue it was waiting on (link lost from this copy)
        #  is answered or fixed.
        groups: dict[type, list] = defaultdict(list)
        for entry in tracks:
            if isinstance(entry, str):
                groups[str].append(entry)
            elif isinstance(entry, dict):
                groups[dict].append(entry)
            else:
                groups[Track].append(entry)

        rows = []
        for kind, entries in groups.items():
            if kind is str:
                resolved = await self.client.decode_tracks(entries, raise_on_failure=False)
            elif kind is dict:
                resolved = [from_dict(data_class=Track, data=entry) for entry in entries]
            else:
                resolved = entries
            for track in resolved:
                rows.append(await TrackRow.get_or_create(track))
        return rows

    async def update_tracks(self, tracks: list[str | JSON_DICT_TYPE | Track]) -> None:
        """Replace the tracks of the playlist, creating the playlist row if needed.

        Parameters
        ----------
        tracks : list[str | dict | Track]
            The new tracks of the playlist (encoded strings, track dicts or ``Track`` objects).
        """
        playlist_row = await PlaylistRow.objects().get_or_create(PlaylistRow.id == self.id)
        try:
            old_tracks = await playlist_row.get_m2m(PlaylistRow.tracks)
        except ValueError:
            # NOTE(review): presumably raised when there are no related tracks yet — confirm.
            old_tracks = []
        new_tracks = await self._resolve_track_rows(tracks)
        if old_tracks:
            await playlist_row.remove_m2m(*old_tracks, m2m=PlaylistRow.tracks)
        if new_tracks:
            await playlist_row.add_m2m(*new_tracks, m2m=PlaylistRow.tracks)
        await self.update_cache((self.exists, True))
        # ``size`` is invalidated rather than primed with ``len(tracks)``: entries that fail to
        # decode never become rows, so the input length can overstate the stored count.
        await self.invalidate_cache(self.fetch_tracks, self.fetch_first, self.size, self.fetch_all)

    @maybe_cached
    async def size(self) -> int:
        """Count the tracks of the playlist.

        Returns
        -------
        int
            The number of tracks returned by ``fetch_tracks`` (0 when there are none).
        """
        tracks = await self.fetch_tracks()
        return len(tracks) if tracks else 0

    async def add_track(self, tracks: list[str | Track | JSON_DICT_TYPE]) -> None:
        """Add tracks to the playlist, creating the playlist row if needed.

        Parameters
        ----------
        tracks : list[str | Track | dict]
            The tracks to add (encoded strings, ``Track`` objects or track dicts).
        """
        playlist_row = await PlaylistRow.objects().get_or_create(PlaylistRow.id == self.id)
        new_tracks = await self._resolve_track_rows(tracks)
        if new_tracks:
            await playlist_row.add_m2m(*new_tracks, m2m=PlaylistRow.tracks)
        await self.invalidate_cache(self.fetch_tracks, self.fetch_all, self.size, self.fetch_first, self.exists)

    async def bulk_remove_tracks(self, tracks: list[str]) -> None:
        """Remove tracks from the playlist by their encoded string.

        Does nothing when ``tracks`` is empty or the playlist row does not exist.

        Parameters
        ----------
        tracks : list[str]
            The encoded strings of the tracks to remove.
        """
        if not tracks:
            return
        playlist = await PlaylistRow.objects().where(PlaylistRow.id == self.id).first()
        if playlist is None:
            # No stored playlist, so there is nothing to detach.
            return
        rows = await TrackRow.objects().where(TrackRow.encoded.is_in(tracks))
        if rows:
            await playlist.remove_m2m(*rows, m2m=PlaylistRow.tracks)
        await self.invalidate_cache(self.fetch_tracks, self.fetch_all, self.size, self.fetch_first, self.exists)

    async def remove_track(self, track: str) -> None:
        """Remove a single track from the playlist.

        Parameters
        ----------
        track : str
            The encoded string of the track to remove.
        """
        return await self.bulk_remove_tracks([track])

    async def remove_all_tracks(self) -> None:
        """Remove all tracks from the playlist, creating the playlist row if it does not exist."""
        # get_or_create (as in update_tracks) keeps the ``exists=True`` cache entry below truthful
        # and avoids calling ``get_m2m`` on ``None`` for a missing row.
        playlist = await PlaylistRow.objects().get_or_create(PlaylistRow.id == self.id)
        try:
            tracks = await playlist.get_m2m(PlaylistRow.tracks)
        except ValueError:
            tracks = []
        if tracks:
            await playlist.remove_m2m(*tracks, m2m=PlaylistRow.tracks)
        await self.update_cache((self.fetch_tracks, []), (self.size, 0), (self.exists, True), (self.fetch_first, None))
        await self.invalidate_cache(self.fetch_all)

    async def delete(self) -> None:
        """Delete the playlist row from the database and drop every cached result for it."""
        await PlaylistRow.delete().where(PlaylistRow.id == self.id)
        await self.invalidate_cache()

    async def can_manage(self, bot: DISCORD_BOT_TYPE, requester: discord.abc.User) -> bool:  # noqa
        """Check if the requester can manage the playlist.

        Bundled playlists are never manageable; bot owners can manage everything else; playlists
        whose scope is the bot user's id are owner-only; otherwise only the author can manage.

        Parameters
        ----------
        bot : DISCORD_BOT_TYPE
            The bot instance.
        requester : discord.abc.User
            The requester.

        Returns
        -------
        bool
            Whether the requester can manage the playlist.
        """
        if self.id in BUNDLED_PLAYLIST_IDS:
            return False
        # Default to None so a bot without ``owner_ids`` does not raise AttributeError.
        owner_ids = getattr(bot, "owner_ids", None) or ()
        if requester.id in owner_ids or requester.id == bot.owner_id:
            return True
        if await self.fetch_scope() == bot.user.id:
            return False
        return await self.fetch_author() == requester.id

    async def get_scope_name(
        self, bot: DISCORD_BOT_TYPE, mention: bool = True, guild: discord.Guild | None = None
    ) -> str:
        """Get a human-readable name for the scope of the playlist.

        The scope id is matched, in order, against the bot user (global), a guild the bot can see,
        a channel/thread in ``guild``, and a member of ``guild`` or any user the bot can see.

        Parameters
        ----------
        bot : DISCORD_BOT_TYPE
            The bot instance.
        mention : bool
            Whether to use a mention where the scope object is mentionable.
        guild : discord.Guild | None
            The guild used to resolve channel and member scopes.

        Returns
        -------
        str
            The translated scope label, e.g. ``(Server) <name>`` or ``(Invalid) <id>``.
        """
        original_scope = await self.fetch_scope()
        if bot.user.id == original_scope:
            return _("(Global) {user_name_variable_do_not_translate}").format(
                user_name_variable_do_not_translate=bot.user.mention if mention else bot.user
            )
        if scope_guild := bot.get_guild(original_scope):
            return _("(Server) {guild_name_variable_do_not_translate}").format(
                guild_name_variable_do_not_translate=scope_guild.name
            )
        if guild and (channel := guild.get_channel_or_thread(original_scope)):
            return _("(Channel) {channel_name_variable_do_not_translate}").format(
                channel_name_variable_do_not_translate=channel.mention if mention else channel.name
            )
        scope_user = (guild.get_member(original_scope) if guild else None) or bot.get_user(original_scope)
        if scope_user:
            return _("(User) {user_name_variable_do_not_translate}").format(
                user_name_variable_do_not_translate=scope_user.mention if mention else scope_user
            )
        return _("(Invalid) {scope_name_variable_do_not_translate}").format(
            scope_name_variable_do_not_translate=original_scope
        )

    async def get_author_name(self, bot: DISCORD_BOT_TYPE, mention: bool = True) -> str | None:
        """Get the name of the author of the playlist.

        Parameters
        ----------
        bot : DISCORD_BOT_TYPE
            The bot instance.
        mention : bool
            Whether to return a mention when the author is a user the bot can see.

        Returns
        -------
        str | None
            The author's mention or name, or the raw author id as a string if the user is unknown.
        """
        author = await self.fetch_author()
        if user := bot.get_user(author):
            return f"{user.mention}" if mention else f"{user}"
        return f"{author}"

    async def get_name_formatted(self, with_url: bool = True, escape: bool = True) -> str:
        """Get the name of the playlist formatted as bold markdown.

        Parameters
        ----------
        with_url : bool
            Whether to turn the name into a markdown link when the stored url starts with ``http``.
        escape : bool
            Whether to markdown-escape the name.

        Returns
        -------
        str
            The formatted name.
        """
        # ``or ""`` guards against a missing name so ``sub`` never receives None.
        name = SQUARE_BRACKETS.sub("", await self.fetch_name() or "").strip()
        display = discord.utils.escape_markdown(name) if escape else name
        if with_url:
            url = await self.fetch_url()
            if url and url.startswith("http"):
                return f"**[{display}]({url})**"
        return f"**{display}**"

    @contextlib.asynccontextmanager
    async def to_yaml(self, guild: discord.Guild) -> AsyncIterator[tuple[io.BytesIO, str | None]]:
        """Serialize the playlist to a YAML file.

        The YAML is compressed (brotli when ``BROTLI_ENABLED``, otherwise gzip) only when its byte
        length exceeds ``guild.filesize_limit``. The yielded buffer is positioned at offset 0 and is
        closed when the context exits.

        Parameters
        ----------
        guild : discord.Guild
            The guild where the yaml will be sent to.

        Yields
        ------
        tuple[io.BytesIO, str | None]
            The YAML file and the compression type (``"brotli"``, ``"gzip"`` or ``None``).
        """
        data = await self.fetch_all()
        name = data["name"]
        # With ``encoding`` set and no stream, safe_dump returns the encoded bytes.
        raw = yaml.safe_dump(data, default_flow_style=False, sort_keys=False, encoding="utf-8")
        LOGGER.debug("SIZE UNCOMPRESSED playlist (%s): %s", name, len(raw))
        # Compare the payload length (not sys.getsizeof of the buffer object) against the limit.
        if len(raw) <= guild.filesize_limit:
            with io.BytesIO(raw) as bio:
                yield bio, None
            return
        if BROTLI_ENABLED:
            compression = "brotli"
            payload = brotli.compress(raw)
        else:
            compression = "gzip"
            payload = gzip.compress(raw, compresslevel=9)
        LOGGER.debug("SIZE COMPRESSED playlist [%s] (%s): %s", compression, name, len(payload))
        with io.BytesIO(payload) as cbio:
            yield cbio, compression

    async def bulk_update(
        self, scope: int, name: str, author: int, url: str | None, tracks: list[str | JSON_DICT_TYPE | Track]
    ) -> None:
        """Create or overwrite the playlist's metadata and replace its tracks in one call.

        Parameters
        ----------
        scope : int
            The new scope.
        name : str
            The new name.
        author : int
            The new author.
        url : str | None
            The new url.
        tracks : list[str | dict | Track]
            The new tracks (encoded strings, track dicts or ``Track`` objects).
        """
        defaults = {
            PlaylistRow.name: name,
            PlaylistRow.author: author,
            PlaylistRow.scope: scope,
            PlaylistRow.url: url,
        }
        playlist_row = await PlaylistRow.objects().get_or_create(PlaylistRow.id == self.id, defaults)
        # noinspection PyProtectedMember
        if not playlist_row._was_created:
            # An existing row did not receive ``defaults`` from get_or_create, so apply them now.
            await PlaylistRow.update(defaults).where(PlaylistRow.id == self.id)
        try:
            old_tracks = await playlist_row.get_m2m(PlaylistRow.tracks)
        except ValueError:
            old_tracks = []
        new_tracks = await self._resolve_track_rows(tracks)
        if old_tracks:
            await playlist_row.remove_m2m(*old_tracks, m2m=PlaylistRow.tracks)
        if new_tracks:
            await playlist_row.add_m2m(*new_tracks, m2m=PlaylistRow.tracks)
        await self.invalidate_cache()

    @classmethod
    async def from_yaml(cls, context: PyLavContext, scope: int, url: str) -> Playlist:
        """Deserialize a playlist from a YAML file.

        Files whose url contains ``.gz.pylav`` / ``.br.pylav`` are gzip / brotli decompressed first.

        Parameters
        ----------
        context : PyLavContext
            The context.
        scope : int
            The scope of the playlist.
        url : str
            The url of the playlist.

        Returns
        -------
        Playlist
            The playlist.

        Raises
        ------
        InvalidPlaylistException
            If the file cannot be downloaded, decompressed or parsed, or lacks the ``name``,
            ``url`` or ``tracks`` keys.
        """
        try:
            async with aiohttp.ClientSession(auto_decompress=False, json_serialize=json.dumps) as session:
                async with session.get(url) as response:
                    response.raise_for_status()
                    raw = await response.read()
            if ".gz.pylav" in url:
                raw = gzip.decompress(raw)
            elif ".br.pylav" in url:
                raw = brotli.decompress(raw)
            data = yaml.safe_load(raw)
            # Read the required keys inside the try so a malformed file surfaces as
            # InvalidPlaylistException instead of a bare KeyError/TypeError.
            name, playlist_url, tracks = data["name"], data["url"], data["tracks"]
        except Exception as e:
            raise InvalidPlaylistException(f"Invalid playlist file - {e}") from e
        # NOTE(review): the constructor argument and the ``author`` value were missing from the
        # reviewed copy; restored as the invoking message id and the invoking user — confirm.
        playlist = cls(id=context.message.id)
        await playlist.bulk_update(
            scope=scope,
            name=name,
            author=context.author.id,
            url=playlist_url,
            tracks=tracks,
        )
        return playlist

    async def fetch_index(self, index: int) -> JSON_DICT_TYPE | None:
        """Get the track at the index.

        Parameters
        ----------
        index : int
            The index of the track; negative values count from the end like list indexing.

        Returns
        -------
        dict | None
            The track at the index, or ``None`` when the playlist is empty or the index is out of range.
        """
        tracks = await self.fetch_tracks()
        if not tracks:
            return None
        try:
            return tracks[index]
        except IndexError:
            return None

    @maybe_cached
    async def fetch_first(self) -> JSON_DICT_TYPE | None:
        """Get the first track.

        Returns
        -------
        dict | None
            The first track, or ``None`` when the playlist is empty.
        """
        return await self.fetch_index(0)

    async def fetch_random(self) -> JSON_DICT_TYPE | None:
        """Get a random track.

        Returns
        -------
        dict | None
            A uniformly chosen track, or ``None`` when the playlist is empty.
        """
        size = await self.size()
        if not size:
            return None
        # randrange excludes ``size``; randint(0, size) could pick one past the last track.
        return await self.fetch_index(random.randrange(size))