Source code for pylav.storage.models.query

from __future__ import annotations

import random
from collections import defaultdict
from dataclasses import dataclass
from datetime import datetime

from dacite import from_dict

from pylav.constants.config import READ_CACHING_ENABLED
from pylav.helpers.singleton import SingletonCachedByKey
from pylav.nodes.api.responses.track import Track
from pylav.storage.database.cache.decodators import maybe_cached
from pylav.storage.database.cache.model import CachedModel
from pylav.storage.database.tables.queries import QueryRow
from pylav.storage.database.tables.tracks import TrackRow
from pylav.type_hints.dict_typing import JSON_DICT_TYPE


@dataclass(eq=True, slots=True, unsafe_hash=True, order=True, kw_only=True, frozen=True)
class Query(CachedModel, metaclass=SingletonCachedByKey):
    id: str

    def get_cache_key(self) -> str:
        return self.id
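
    # Note (not part of the original module source): SingletonCachedByKey keys
    # instances by get_cache_key(), i.e. by the query identifier, so repeated
    # construction with the same ``id`` is expected to hand back the same cached
    # Query object rather than a fresh one. This is an interpretation of the
    # metaclass name and its usage here, not a documented guarantee.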

    @maybe_cached
    async def exists(self) -> bool:
        """Check if the config exists.

        Returns
        -------
        bool
            Whether the config exists.
        """
        return await QueryRow.exists().where(QueryRow.identifier == self.id)

    async def delete(self) -> None:
        """Delete the query from the database"""
        await QueryRow.delete().where(QueryRow.identifier == self.id)
        await self.invalidate_cache()

    @maybe_cached
    async def size(self) -> int:
        """Count the tracks of the playlist.

        Returns
        -------
        int
            The number of tracks in the playlist.
        """
        tracks = await self.fetch_tracks()
        return len(tracks) if tracks else 0

    @maybe_cached
    async def fetch_tracks(self) -> list[str | JSON_DICT_TYPE]:
        """Get the tracks of the playlist.

        Returns
        -------
        list[str | JSON_DICT_TYPE]
            The tracks of the playlist.
        """
        response = (
            await QueryRow.select(
                QueryRow.tracks(TrackRow.encoded, TrackRow.info, TrackRow.pluginInfo, load_json=True)
            )
            .where(QueryRow.identifier == self.id)
            .first()
            .output(load_json=True, nested=True)
        )
        return response["tracks"] if response else []

    async def update_tracks(self, tracks: list[str | Track]):
        """Update the tracks of the playlist.

        Parameters
        ----------
        tracks: list[str | Track]
            The tracks of the playlist.
        """
        query_row = await QueryRow.objects().get_or_create(QueryRow.identifier == self.id)
        try:
            old_tracks = await query_row.get_m2m(QueryRow.tracks)
        except ValueError:
            old_tracks = []
        new_tracks = []
        # TODO: Optimize this, after https://github.com/piccolo-orm/piccolo/discussions/683 is answered or fixed
        _temp = defaultdict(list)
        for x in tracks:
            _temp[type(x)].append(x)
        for entry_type, entry_list in _temp.items():
            if entry_type == str:
                for track_object in await self.client.decode_tracks(entry_list, raise_on_failure=False):
                    new_tracks.append(await TrackRow.get_or_create(track_object))
            elif entry_type == dict:
                for track_object in entry_list:
                    new_tracks.append(await TrackRow.get_or_create(from_dict(data_class=Track, data=track_object)))
            else:
                for track_object in entry_list:
                    new_tracks.append(await TrackRow.get_or_create(track_object))
        if old_tracks:
            await query_row.remove_m2m(*old_tracks, m2m=QueryRow.tracks)
        if new_tracks:
            await query_row.add_m2m(*new_tracks, m2m=QueryRow.tracks)
        await self.invalidate_cache(self.fetch_tracks, self.fetch_first)
        await self.update_cache(
            (self.size, len(tracks)),
            (self.exists, True),
        )
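
    # Note (not part of the original module source): update_tracks accepts a mix
    # of entry types. Plain strings are treated as Lavalink-encoded tracks and
    # decoded in bulk via client.decode_tracks(), dicts are parsed into Track
    # responses with dacite.from_dict, and anything else is assumed to already be
    # a Track response and is stored as-is through TrackRow.get_or_create.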

    @maybe_cached
    async def fetch_plugin_info(self) -> JSON_DICT_TYPE:
        """Get the plugin info of the playlist.

        Returns
        -------
        JSON_DICT_TYPE
            The plugin info of the playlist.
        """
        response = (
            await QueryRow.select(QueryRow.pluginInfo)
            .where(QueryRow.identifier == self.id)
            .first()
            .output(load_json=True, nested=True)
        )
        return response["pluginInfo"] if response else {}

    async def update_plugin_info(self, plugin_info: JSON_DICT_TYPE) -> None:
        """Update the plugin info of the playlist.

        Parameters
        ----------
        plugin_info: JSON_DICT_TYPE
            The plugin info of the playlist.
        """
        await QueryRow.insert(QueryRow(identifier=self.id, pluginInfo=plugin_info)).on_conflict(
            action="DO UPDATE", target=QueryRow.identifier, values=[QueryRow.pluginInfo]
        )
        await self.update_cache((self.fetch_plugin_info, plugin_info), (self.exists, True))
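
    # Note (not part of the original module source): the insert above relies on
    # Piccolo's on_conflict upsert. If a QueryRow with this identifier already
    # exists, only the pluginInfo column is overwritten; otherwise a new row is
    # created. update_info, update_name and update_last_updated below reuse the
    # same pattern for their respective columns.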

    @maybe_cached
    async def fetch_info(self) -> JSON_DICT_TYPE:
        """Get the info of the playlist.

        Returns
        -------
        JSON_DICT_TYPE
            The info of the playlist.
        """
        response = (
            await QueryRow.select(QueryRow.info)
            .where(QueryRow.identifier == self.id)
            .first()
            .output(load_json=True, nested=True)
        )
        return response["info"] if response else {}

    async def update_info(self, info: JSON_DICT_TYPE) -> None:
        """Update the info of the playlist.

        Parameters
        ----------
        info: JSON_DICT_TYPE
            The info of the playlist.
        """
        await QueryRow.insert(QueryRow(identifier=self.id, info=info)).on_conflict(
            action="DO UPDATE", target=QueryRow.identifier, values=[QueryRow.info]
        )
        await self.update_cache((self.fetch_info, info), (self.exists, True))

    @maybe_cached
    async def fetch_name(self) -> str:
        """Get the name of the playlist.

        Returns
        -------
        str
            The name of the playlist.
        """
        response = (
            await QueryRow.select(QueryRow.name)
            .where(QueryRow.identifier == self.id)
            .first()
            .output(load_json=True, nested=True)
        )
        return response["name"] if response else QueryRow.name.default

    async def update_name(self, name: str) -> None:
        """Update the name of the playlist.

        Parameters
        ----------
        name: str
            The name of the playlist.
        """
        await QueryRow.insert(QueryRow(identifier=self.id, name=name)).on_conflict(
            action="DO UPDATE", target=QueryRow.identifier, values=[QueryRow.name]
        )
        await self.update_cache((self.fetch_name, name), (self.exists, True))

    @maybe_cached
    async def fetch_last_updated(self) -> datetime:
        """Get the last updated time of the playlist.

        Returns
        -------
        datetime
            The last updated time of the playlist.
        """
        response = (
            await QueryRow.select(QueryRow.last_updated)
            .where(QueryRow.identifier == self.id)
            .first()
            .output(load_json=True, nested=True)
        )
        return response["last_updated"] if response else QueryRow.last_updated.default

    async def update_last_updated(self) -> None:
        """Update the last updated time of the playlist"""
        await QueryRow.insert(
            QueryRow(identifier=self.id, last_updated=QueryRow.last_updated.default.python())
        ).on_conflict(action="DO UPDATE", target=QueryRow.identifier, values=[QueryRow.last_updated])
        await self.update_cache(
            (self.fetch_last_updated, QueryRow.last_updated.default.python()),
            (self.exists, True),
        )

    async def bulk_update(
        self,
        tracks: list[str | Track],
        name: str,
        info: JSON_DICT_TYPE | None = None,
        plugin_info: JSON_DICT_TYPE | None = None,
    ) -> None:
        """Bulk update the query.

        Parameters
        ----------
        tracks: list[str | Track]
            The tracks of the playlist.
        name: str
            The name of the playlist.
        info: JSON_DICT_TYPE | None
            The info of the playlist, if it should be updated.
        plugin_info: JSON_DICT_TYPE | None
            The plugin info of the playlist, if it should be updated.
        """
        defaults = {QueryRow.name: name}
        if info is not None:
            defaults[QueryRow.info] = info
        if plugin_info is not None:
            defaults[QueryRow.pluginInfo] = plugin_info
        query_row = await QueryRow.objects().get_or_create(QueryRow.identifier == self.id, defaults)
        # noinspection PyProtectedMember
        if not query_row._was_created:
            await QueryRow.update(defaults).where(QueryRow.identifier == self.id)
        try:
            old_tracks = await query_row.get_m2m(QueryRow.tracks)
        except ValueError:
            old_tracks = []
        new_tracks = []
        # TODO: Optimize this, after https://github.com/piccolo-orm/piccolo/discussions/683 is answered or fixed
        _temp = defaultdict(list)
        for x in tracks:
            _temp[type(x)].append(x)
        for entry_type, entry_list in _temp.items():
            if entry_type == str:
                for track_object in await self.client.decode_tracks(entry_list, raise_on_failure=False):
                    new_tracks.append(await TrackRow.get_or_create(track_object))
            elif entry_type == dict:
                for track_object in entry_list:
                    new_tracks.append(await TrackRow.get_or_create(from_dict(data_class=Track, data=track_object)))
            else:
                for track_object in entry_list:
                    new_tracks.append(await TrackRow.get_or_create(track_object))
        if old_tracks:
            await query_row.remove_m2m(*old_tracks, m2m=QueryRow.tracks)
        if new_tracks:
            await query_row.add_m2m(*new_tracks, m2m=QueryRow.tracks)
        await self.invalidate_cache(self.fetch_tracks, self.fetch_first)
        await self.update_cache(
            (self.size, len(tracks)),
            (self.fetch_name, name),
            (self.fetch_last_updated, QueryRow.last_updated.default.python()),
            (self.exists, True),
        )

    async def fetch_index(self, index: int) -> JSON_DICT_TYPE | None:
        """Get the track at the index.

        Parameters
        ----------
        index: int
            The index of the track.

        Returns
        -------
        JSON_DICT_TYPE | None
            The track at the index, or None if the index is out of range.
        """
        if READ_CACHING_ENABLED:
            tracks = await self.fetch_tracks()
            return tracks[index] if index < len(tracks) else None
        else:
            tracks = await self.fetch_tracks()
            if tracks and len(tracks) > index:
                return tracks[index]

    @maybe_cached
    async def fetch_first(self) -> JSON_DICT_TYPE | None:
        """Get the first track.

        Returns
        -------
        JSON_DICT_TYPE | None
            The first track, or None if the playlist is empty.
        """
        return await self.fetch_index(0)

    async def fetch_random(self) -> JSON_DICT_TYPE | None:
        """Get a random track.

        Returns
        -------
        JSON_DICT_TYPE | None
            A random track, or None if the playlist is empty.
        """
        return await self.fetch_index(random.randint(0, max(await self.size() - 1, 0)))

    async def fetch_bulk(
        self, info: bool = False, name: bool = False, pluginInfo: bool = False, tracks: bool = False
    ) -> JSON_DICT_TYPE | None:
        """Fetch the selected columns of the query in a single call.

        Parameters
        ----------
        info: bool
            Whether to include the info column.
        name: bool
            Whether to include the name column.
        pluginInfo: bool
            Whether to include the pluginInfo column.
        tracks: bool
            Whether to include the related tracks.

        Returns
        -------
        JSON_DICT_TYPE | None
            The selected data of the query, or None if the query does not exist.
        """
        columns = [QueryRow.identifier]
        if name:
            columns.append(QueryRow.name)
        if info:
            columns.append(QueryRow.info)
        if pluginInfo:
            columns.append(QueryRow.pluginInfo)
        if tracks:
            columns.append(QueryRow.tracks(TrackRow.encoded, TrackRow.info, TrackRow.pluginInfo, load_json=True))
        response = (
            await QueryRow.select(*columns)
            .where(QueryRow.identifier == self.id)
            .first()
            .output(load_json=True, nested=True)
        )
        return response if response else None
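

# Hypothetical usage sketch, not part of the original module source. It assumes a
# Query instance obtained from an initialised PyLav client (the exact entry point
# is not shown in this module) and a running event loop; the encoded track string
# and metadata below are placeholders, not real Lavalink data.
async def _example_usage(query: Query) -> None:
    # Upsert the row, replace its track links and refresh the cache in one call.
    await query.bulk_update(
        tracks=["QAAA-placeholder-encoded-track"],  # str entries go through client.decode_tracks
        name="ytsearch:example query",
        info={"note": "placeholder info payload"},
    )
    # Selectively join columns in a single query; tracks=True nests the track rows.
    response = await query.fetch_bulk(name=True, tracks=True)
    if response:
        print(response["name"], len(response.get("tracks") or []))
    # size() and fetch_first() are wrapped in @maybe_cached above.
    print(await query.size(), await query.fetch_first())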