# Source code for pylav.storage.controllers.equalizers

from __future__ import annotations

from collections.abc import AsyncIterator
from typing import TYPE_CHECKING

import discord
from piccolo.columns import Float

from pylav.exceptions.database import EntryNotFoundException
from pylav.logging import getLogger
from pylav.storage.database.tables.equalizer import EqualizerRow
from pylav.storage.models import equilizer
from pylav.type_hints.bot import DISCORD_BOT_TYPE

# Logger for this controller module; note it is currently not used by EqualizerController below.
LOGGER = getLogger("PyLav.Database.Controller.Equalizer")

if TYPE_CHECKING:
    from pylav.core.client import Client


class EqualizerController:
    """Database access layer for saved equalizers stored in ``EqualizerRow``.

    Every equalizer row carries a ``scope`` (bot user id for globals, or a user/guild/
    text-channel/voice-channel id) and an ``author`` id; the lookup helpers below filter on
    those columns and wrap each row in :class:`equilizer.Equalizer`.
    """

    __slots__ = ("_client",)

    def __init__(self, client: Client) -> None:
        self._client = client

    @property
    def client(self) -> Client:
        """The PyLav client this controller was created with."""
        return self._client

    @staticmethod
    async def get_equalizer_by_name(equalizer_name: str, limit: int | None = None) -> list[equilizer.Equalizer]:
        """Return equalizers whose name contains ``equalizer_name``.

        The match is a case-insensitive ``ILIKE '%<name>%'`` substring search. ``%`` and ``_``
        inside ``equalizer_name`` are not escaped, so they behave as SQL wildcards.

        Args:
            equalizer_name: Text to search for. Non-``str`` values (e.g. an ``int`` forwarded by
                :meth:`get_equalizer_by_name_or_id` after an id miss) are converted with ``str()``.
            limit: Maximum number of rows to return; ``None`` means no limit.

        Returns:
            The matching equalizers.

        Raises:
            EntryNotFoundException: If no equalizer matches.
        """
        pattern = f"%{str(equalizer_name).lower()}%"
        query = EqualizerRow.select().where(EqualizerRow.name.ilike(pattern))
        if limit is not None:
            query = query.limit(limit)
        # ``.output()`` has to be chained onto the query *before* awaiting it: awaiting first
        # yields a plain list, which has no ``.output`` method.
        equalizers = await query.output(load_json=True, nested=True)
        if not equalizers:
            raise EntryNotFoundException(f"Equalizer with name {equalizer_name} not found")
        return [equilizer.Equalizer(**equalizer) for equalizer in equalizers]

    @staticmethod
    async def get_equalizer_by_id(identifier: int | str) -> equilizer.Equalizer:
        """Return the equalizer whose ``id`` equals ``identifier``.

        Args:
            identifier: The equalizer id, or anything ``int()`` can convert to one.

        Raises:
            EntryNotFoundException: If ``identifier`` cannot be converted to ``int`` or no row
                has that id.
        """
        try:
            identifier = int(identifier)
        except (TypeError, ValueError) as e:
            # TypeError covers e.g. ``None``; report it the same way as a malformed id so
            # callers only have to handle EntryNotFoundException.
            raise EntryNotFoundException(f"Equalizer with id {identifier} not found") from e
        equalizer = (
            await EqualizerRow.select()
            .where(EqualizerRow.id == identifier)
            .first()
            .output(load_json=True, nested=True)
        )
        if not equalizer:
            raise EntryNotFoundException(f"Equalizer with ID {identifier} not found")
        return equilizer.Equalizer(**equalizer)

    async def get_equalizer_by_name_or_id(
        self, equalizer_name_or_id: int | str, limit: int | None = None
    ) -> list[equilizer.Equalizer]:
        """Look ``equalizer_name_or_id`` up as an id first, then as a name substring.

        Returns:
            A one-element list on an id hit, otherwise the result of
            :meth:`get_equalizer_by_name`.

        Raises:
            EntryNotFoundException: If neither lookup finds anything.
        """
        try:
            return [await self.get_equalizer_by_id(equalizer_name_or_id)]
        except EntryNotFoundException:
            return await self.get_equalizer_by_name(equalizer_name_or_id, limit=limit)

    @staticmethod
    async def get_equalizers_by_author(author: int) -> list[equilizer.Equalizer]:
        """Return every equalizer whose ``author`` column equals ``author``.

        Raises:
            EntryNotFoundException: If there are none.
        """
        equalizers = await EqualizerRow.select().where(EqualizerRow.author == author).output(load_json=True, nested=True)
        if not equalizers:
            raise EntryNotFoundException(f"Equalizer with author {author} not found")
        return [equilizer.Equalizer(**equalizer) for equalizer in equalizers]

    @staticmethod
    async def get_equalizers_by_scope(scope: int) -> list[equilizer.Equalizer]:
        """Return every equalizer whose ``scope`` column equals ``scope``.

        Raises:
            EntryNotFoundException: If there are none.
        """
        equalizers = await EqualizerRow.select().where(EqualizerRow.scope == scope).output(load_json=True, nested=True)
        if not equalizers:
            raise EntryNotFoundException(f"Equalizer with scope {scope} not found")
        return [equilizer.Equalizer(**equalizer) for equalizer in equalizers]

    @staticmethod
    async def get_all_equalizers() -> AsyncIterator[equilizer.Equalizer]:
        """Yield every stored equalizer (all rows are fetched with one query, then yielded)."""
        for entry in await EqualizerRow.select().output(load_json=True, nested=True):
            yield equilizer.Equalizer(**entry)

    @staticmethod
    def _get_equalizer_band_defaults(*args: tuple[Float, float | None]) -> dict[Float, float]:
        """Map band columns to values, dropping pairs whose value is ``None`` or ``0.0``.

        Args:
            *args: ``(column, value)`` pairs.

        Returns:
            ``{column: value}`` for the pairs that were kept.
        """
        null_values = {None, 0.0}
        return {band: value for band, value in args if value not in null_values}

    @staticmethod
    async def create_or_update_equalizer(
        identifier: int,
        scope: int,
        author: int,
        name: str,
        description: str | None = None,
        band_25: float | None = None,
        band_40: float | None = None,
        band_63: float | None = None,
        band_100: float | None = None,
        band_160: float | None = None,
        band_250: float | None = None,
        band_400: float | None = None,
        band_630: float | None = None,
        band_1000: float | None = None,
        band_1600: float | None = None,
        band_2500: float | None = None,
        band_4000: float | None = None,
        band_6300: float | None = None,
        band_10000: float | None = None,
        band_16000: float | None = None,
    ) -> equilizer.Equalizer:
        """Insert the equalizer with id ``identifier``, or update it if it already exists.

        ``scope``, ``author``, ``name`` and ``description`` are always written. Band values
        are written only when they are neither ``None`` nor ``0.0``.

        NOTE(review): because 0.0/None bands are omitted from the UPDATE as well, updating an
        existing equalizer cannot reset a band back to 0.0 — the previously stored value is
        kept. Confirm this is intended.

        Returns:
            The equalizer as stored after the insert/update.
        """
        values = {
            EqualizerRow.scope: scope,
            EqualizerRow.author: author,
            EqualizerRow.name: name,
            EqualizerRow.description: description,
        }
        values |= EqualizerController._get_equalizer_band_defaults(
            (EqualizerRow.band_25, band_25),
            (EqualizerRow.band_40, band_40),
            (EqualizerRow.band_63, band_63),
            (EqualizerRow.band_100, band_100),
            (EqualizerRow.band_160, band_160),
            (EqualizerRow.band_250, band_250),
            (EqualizerRow.band_400, band_400),
            (EqualizerRow.band_630, band_630),
            (EqualizerRow.band_1000, band_1000),
            (EqualizerRow.band_1600, band_1600),
            (EqualizerRow.band_2500, band_2500),
            (EqualizerRow.band_4000, band_4000),
            (EqualizerRow.band_6300, band_6300),
            (EqualizerRow.band_10000, band_10000),
            (EqualizerRow.band_16000, band_16000),
        )
        equalizer = (
            await EqualizerRow.objects()
            .output(load_json=True)
            .get_or_create(EqualizerRow.id == identifier, defaults=values)
        )
        # noinspection PyProtectedMember
        if not equalizer._was_created:
            await EqualizerRow.update(values).where(EqualizerRow.id == identifier)
            # get_or_create returned the row as it was *before* this update; re-read it so the
            # returned Equalizer reflects the values just written.
            refreshed = await EqualizerRow.objects().output(load_json=True).get(EqualizerRow.id == identifier)
            if refreshed is not None:
                equalizer = refreshed
        return equilizer.Equalizer(**equalizer.to_dict())

    @staticmethod
    async def delete_equalizer(equalizer_id: int) -> None:
        """Delete the equalizer row with id ``equalizer_id`` (does not check that one existed)."""
        await EqualizerRow.delete().where(EqualizerRow.id == equalizer_id)

    @staticmethod
    async def get_all_equalizers_by_author(author: int) -> AsyncIterator[equilizer.Equalizer]:
        """Yield every equalizer whose ``author`` column equals ``author``."""
        for entry in (
            await EqualizerRow.select().where(EqualizerRow.author == author).output(load_json=True, nested=True)
        ):
            yield equilizer.Equalizer(**entry)

    @staticmethod
    async def get_all_equalizers_by_scope(scope: int) -> AsyncIterator[equilizer.Equalizer]:
        """Yield every equalizer whose ``scope`` column equals ``scope``."""
        for entry in await EqualizerRow.select().where(EqualizerRow.scope == scope).output(load_json=True, nested=True):
            yield equilizer.Equalizer(**entry)

    @staticmethod
    async def get_all_equalizers_by_scope_and_author(scope: int, author: int) -> AsyncIterator[equilizer.Equalizer]:
        """Yield every equalizer matching both ``scope`` and ``author``."""
        for entry in (
            await EqualizerRow.select()
            .where(
                EqualizerRow.scope == scope,
                EqualizerRow.author == author,
            )
            .output(load_json=True, nested=True)
        ):
            yield equilizer.Equalizer(**entry)

    async def get_global_equalizers(self) -> AsyncIterator[equilizer.Equalizer]:
        """Yield every equalizer scoped to the bot's own user id (global equalizers)."""
        for entry in (
            await EqualizerRow.select()
            .where(EqualizerRow.scope == self._client.bot.user.id)  # type: ignore
            .output(load_json=True, nested=True)
        ):
            yield equilizer.Equalizer(**entry)

    async def create_or_update_global_equalizer(
        self,
        identifier: int,
        author: int,
        name: str,
        description: str | None = None,
        band_25: float | None = None,
        band_40: float | None = None,
        band_63: float | None = None,
        band_100: float | None = None,
        band_160: float | None = None,
        band_250: float | None = None,
        band_400: float | None = None,
        band_630: float | None = None,
        band_1000: float | None = None,
        band_1600: float | None = None,
        band_2500: float | None = None,
        band_4000: float | None = None,
        band_6300: float | None = None,
        band_10000: float | None = None,
        band_16000: float | None = None,
    ) -> equilizer.Equalizer:
        """Create or update an equalizer scoped to the bot's own user id (a global equalizer)."""
        return await self.create_or_update_equalizer(
            identifier=identifier,
            scope=self._client.bot.user.id,
            author=author,
            name=name,
            description=description,
            band_25=band_25,
            band_40=band_40,
            band_63=band_63,
            band_100=band_100,
            band_160=band_160,
            band_250=band_250,
            band_400=band_400,
            band_630=band_630,
            band_1000=band_1000,
            band_1600=band_1600,
            band_2500=band_2500,
            band_4000=band_4000,
            band_6300=band_6300,
            band_10000=band_10000,
            band_16000=band_16000,
        )

    async def create_or_update_user_equalizer(
        self,
        user_id: int,
        author: int,
        name: str,
        description: str | None = None,
        band_25: float | None = None,
        band_40: float | None = None,
        band_63: float | None = None,
        band_100: float | None = None,
        band_160: float | None = None,
        band_250: float | None = None,
        band_400: float | None = None,
        band_630: float | None = None,
        band_1000: float | None = None,
        band_1600: float | None = None,
        band_2500: float | None = None,
        band_4000: float | None = None,
        band_6300: float | None = None,
        band_10000: float | None = None,
        band_16000: float | None = None,
    ) -> equilizer.Equalizer:
        """Create or update an equalizer scoped to ``author``.

        NOTE(review): ``user_id`` is written as the equalizer's ``id`` (its identifier) while
        ``author`` is used as the scope; confirm the parameter name reflects that intent.
        """
        return await self.create_or_update_equalizer(
            identifier=user_id,
            scope=author,
            author=author,
            name=name,
            description=description,
            band_25=band_25,
            band_40=band_40,
            band_63=band_63,
            band_100=band_100,
            band_160=band_160,
            band_250=band_250,
            band_400=band_400,
            band_630=band_630,
            band_1000=band_1000,
            band_1600=band_1600,
            band_2500=band_2500,
            band_4000=band_4000,
            band_6300=band_6300,
            band_10000=band_10000,
            band_16000=band_16000,
        )

    async def create_or_update_channel_equalizer(
        self,
        text_channel: discord.abc.MessageableChannel,
        identifier: int,
        author: int,
        name: str,
        description: str | None = None,
        band_25: float | None = None,
        band_40: float | None = None,
        band_63: float | None = None,
        band_100: float | None = None,
        band_160: float | None = None,
        band_250: float | None = None,
        band_400: float | None = None,
        band_630: float | None = None,
        band_1000: float | None = None,
        band_1600: float | None = None,
        band_2500: float | None = None,
        band_4000: float | None = None,
        band_6300: float | None = None,
        band_10000: float | None = None,
        band_16000: float | None = None,
    ) -> equilizer.Equalizer:
        """Create or update an equalizer scoped to ``text_channel.id``."""
        return await self.create_or_update_equalizer(
            identifier=identifier,
            scope=text_channel.id,
            author=author,
            name=name,
            description=description,
            band_25=band_25,
            band_40=band_40,
            band_63=band_63,
            band_100=band_100,
            band_160=band_160,
            band_250=band_250,
            band_400=band_400,
            band_630=band_630,
            band_1000=band_1000,
            band_1600=band_1600,
            band_2500=band_2500,
            band_4000=band_4000,
            band_6300=band_6300,
            band_10000=band_10000,
            band_16000=band_16000,
        )

    async def create_or_update_guild_equalizer(
        self,
        guild: discord.Guild,
        identifier: int,
        author: int,
        name: str,
        description: str | None = None,
        band_25: float | None = None,
        band_40: float | None = None,
        band_63: float | None = None,
        band_100: float | None = None,
        band_160: float | None = None,
        band_250: float | None = None,
        band_400: float | None = None,
        band_630: float | None = None,
        band_1000: float | None = None,
        band_1600: float | None = None,
        band_2500: float | None = None,
        band_4000: float | None = None,
        band_6300: float | None = None,
        band_10000: float | None = None,
        band_16000: float | None = None,
    ) -> equilizer.Equalizer:
        """Create or update an equalizer scoped to ``guild.id``."""
        return await self.create_or_update_equalizer(
            identifier=identifier,
            scope=guild.id,
            author=author,
            name=name,
            description=description,
            band_25=band_25,
            band_40=band_40,
            band_63=band_63,
            band_100=band_100,
            band_160=band_160,
            band_250=band_250,
            band_400=band_400,
            band_630=band_630,
            band_1000=band_1000,
            band_1600=band_1600,
            band_2500=band_2500,
            band_4000=band_4000,
            band_6300=band_6300,
            band_10000=band_10000,
            band_16000=band_16000,
        )

    async def create_or_update_vc_equalizer(
        self,
        voice_channel: discord.channel.VocalGuildChannel,
        identifier: int,
        author: int,
        name: str,
        description: str | None = None,
        band_25: float | None = None,
        band_40: float | None = None,
        band_63: float | None = None,
        band_100: float | None = None,
        band_160: float | None = None,
        band_250: float | None = None,
        band_400: float | None = None,
        band_630: float | None = None,
        band_1000: float | None = None,
        band_1600: float | None = None,
        band_2500: float | None = None,
        band_4000: float | None = None,
        band_6300: float | None = None,
        band_10000: float | None = None,
        band_16000: float | None = None,
    ) -> equilizer.Equalizer:
        """Create or update an equalizer scoped to ``voice_channel.id``."""
        return await self.create_or_update_equalizer(
            identifier=identifier,
            scope=voice_channel.id,
            author=author,
            name=name,
            description=description,
            band_25=band_25,
            band_40=band_40,
            band_63=band_63,
            band_100=band_100,
            band_160=band_160,
            band_250=band_250,
            band_400=band_400,
            band_630=band_630,
            band_1000=band_1000,
            band_1600=band_1600,
            band_2500=band_2500,
            band_4000=band_4000,
            band_6300=band_6300,
            band_10000=band_10000,
            band_16000=band_16000,
        )

    async def get_all_for_user(
        self,
        requester: int,
        *,
        voice_channel: discord.channel.VocalGuildChannel | None = None,
        guild: discord.Guild | None = None,
        channel: discord.abc.MessageableChannel | None = None,
    ) -> tuple[
        list[equilizer.Equalizer],
        list[equilizer.Equalizer],
        list[equilizer.Equalizer],
        list[equilizer.Equalizer],
        list[equilizer.Equalizer],
    ]:
        """
        Gets all equalizers a user has access to in a given context.

        Globals, User specific, Guild specific, Channel specific, VC specific.

        Returns:
            ``(global, user, guild, channel, vc)`` lists, in that order. The guild, channel and
            vc lists are empty when the corresponding keyword argument is ``None``.
        """
        global_equalizers = [p async for p in self.get_all_equalizers_by_scope(scope=self._client.bot.user.id)]
        user_equalizers = [p async for p in self.get_all_equalizers_by_scope(scope=requester)]
        vc_equalizers = []
        guild_equalizers = []
        channel_equalizers = []
        if voice_channel is not None:
            vc_equalizers = [p async for p in self.get_all_equalizers_by_scope(scope=voice_channel.id)]
        if guild is not None:
            guild_equalizers = [p async for p in self.get_all_equalizers_by_scope(scope=guild.id)]
        if channel is not None:
            channel_equalizers = [p async for p in self.get_all_equalizers_by_scope(scope=channel.id)]
        return global_equalizers, user_equalizers, guild_equalizers, channel_equalizers, vc_equalizers

    async def get_manageable_equalizers(
        self, requester: discord.abc.User, bot: DISCORD_BOT_TYPE, *, name_or_id: str | None = None
    ) -> list[equilizer.Equalizer]:
        """Return the equalizers ``requester`` is allowed to manage.

        Args:
            requester: The user whose permissions are checked via ``Equalizer.can_manage``.
            bot: Passed through to ``Equalizer.can_manage``.
            name_or_id: When truthy, only equalizers matching it (by id, else by name) are
                considered; otherwise every stored equalizer is.

        Returns:
            The subset for which ``can_manage`` returned a truthy value; empty if nothing matched.
        """
        if name_or_id:
            try:
                equalizers = await self.get_equalizer_by_name_or_id(name_or_id)
            except EntryNotFoundException:
                equalizers = []
        else:
            try:
                equalizers = [p async for p in self.get_all_equalizers()]
            except EntryNotFoundException:
                equalizers = []
        returning_list = []
        for equalizer in equalizers:
            if await equalizer.can_manage(requester=requester, bot=bot):
                returning_list.append(equalizer)
        return returning_list

    @staticmethod
    async def count() -> int:
        """Returns the number of equalizers in the database."""
        return await EqualizerRow.count()