Source code for pylav.utils.localtracks

from __future__ import annotations

import asyncio
import hashlib
import pathlib
from typing import TYPE_CHECKING

import aiopath
from discord.utils import utcnow
from expiringdict import ExpiringDict
from watchfiles import Change, awatch

from pylav.constants.config import POSTGRES_CONNECTIONS
from pylav.logging import getLogger
from pylav.nodes.api.responses.track import Track
from pylav.players.query.local_files import ALL_EXTENSIONS
from pylav.players.query.obj import Query

if TYPE_CHECKING:
    from pylav.core.client import Client

LOGGER = getLogger("PyLav.LocalTrackCache")


[docs] class LocalTrackCache: """A cache for local tracks.""" __slots__ = ( "__pylav", "__ready", "__monitor", "__shutdown", "__query_cache", "__track_cache", "__root_folder", "__query_lock", "__track_lock", "__path_to_query_cache", "__counter", ) def __init__(self, client: Client, root: str | pathlib.Path | aiopath.Path) -> None: self.__shutdown = False self.__pylav = client self.__query_cache: dict[str, Query] = {} self.__track_cache: dict[str, Track] = {} self.__path_to_query_cache: dict[str, Query] = {} self.__query_lock = asyncio.Lock() self.__track_lock = asyncio.Lock() self.__root_folder = pathlib.Path(root) self.__ready = asyncio.Event() self.__monitor = asyncio.create_task(self.file_watcher()) self.__counter = ExpiringDict(max_len=float("inf"), max_age_seconds=5) def __bool__(self) -> bool: return not self.__shutdown @property def hexdigest_to_query(self) -> dict[str, Query]: """The hexdigest to query cache.""" return self.__query_cache @property def path_to_query(self) -> dict[str, Query]: """The path to query cache.""" return self.__path_to_query_cache @property def path_to_track(self) -> dict[str, Track]: """The path to track cache.""" return self.__track_cache @property def root_folder(self) -> pathlib.Path: """The root folder of the local track cache.""" return self.__root_folder @property def is_ready(self) -> bool: """Whether the local track cache is ready.""" return self.__ready.is_set()
[docs] async def initialize(self): """Initialize the local track cache.""" await self.__pylav.wait_until_ready() await self.update() self.__ready.set()
[docs] async def shutdown(self): """Shutdown the local track cache.""" self.__shutdown = True self.__monitor.cancel() self.__ready.clear() await self.wipe_cache()
async def _add_to_query_cache(self, query: Query, path: str) -> None: if self.__shutdown: return async with self.__query_lock: self.__query_cache[hashlib.md5(f"{query._query}".encode()).hexdigest()] = query self.__path_to_query_cache[path] = query async def _remove_from_query_cache(self, query: Query, path: str) -> None: if self.__shutdown: return async with self.__query_lock: self.__query_cache.pop(hashlib.md5(f"{query._query}".encode()).hexdigest(), None) self.__path_to_query_cache.pop(path, None) async def _add_to_track_cache(self, track: Track, path: pathlib.Path) -> None: if self.__shutdown: return async with self.__track_lock: self.__track_cache[f"{path}"] = track async def _remove_from_track_cache(self, path: pathlib.Path) -> None: if self.__shutdown: return async with self.__track_lock: self.__track_cache.pop(f"{path}", None)
[docs] async def wipe_cache(self) -> None: """Wipe the local track cache.""" await self.__track_lock.acquire() await self.__query_lock.acquire() self.__track_cache.clear() self.__query_cache.clear() self.__path_to_query_cache.clear() self.__track_lock.release() self.__query_lock.release()
[docs] async def file_watcher(self): """A file watcher for the local track cache.""" await self.__ready.wait() async for changes in awatch(self.root_folder, recursive=True): if self.__shutdown: return await self._process_changes(changes)
async def _process_changes(self, changes: set[tuple[Change, str]]) -> None: for change, path in changes: path_obj = pathlib.Path(path) if (not path_obj.is_dir()) and path_obj.suffix.lower() not in ALL_EXTENSIONS: continue match change: case Change.added: await self._process_added(path, path_obj) LOGGER.trace(f"Added {path}") case Change.modified: await self._process_modified(path, path_obj, modified=True) LOGGER.trace(f"Modified {path}") case Change.deleted: await self._process_deleted(path, path_obj) LOGGER.trace(f"Deleted {path}") async def _process_added(self, path: str, path_obj: pathlib.Path, modified: bool = False) -> None: if self.__shutdown: return query = await Query.from_string(path_obj) if path_obj.is_dir(): await self._add_to_query_cache(query, path) return self.__counter["added"] = self.__counter.get("added", default=0) + 1 if self.__counter["added"] % 3 == 10: self.__counter["added"] = 0 should_sleep = True else: should_sleep = False track = await self.__pylav.search_query(query, bypass_cache=modified, sleep=should_sleep) if track.loadType in {"track", "playlist", "search"}: await self._add_to_query_cache(query, path) match track.loadType: case "track": await self._add_to_track_cache(track.data, path_obj) case "playlist": for track in track.data.tracks: await self._add_to_track_cache(track, path_obj) case "search": for track in track.data: await self._add_to_track_cache(track, path_obj) async def _process_modified(self, path: str, path_obj: pathlib.Path, modified: bool = True) -> None: if self.__shutdown: return query = await Query.from_string(path_obj) await self._remove_from_query_cache(query, path) await self._remove_from_track_cache(path_obj) if path_obj.is_dir(): await self._add_to_query_cache(query, path) return track = await self.__pylav.search_query(query, bypass_cache=modified, sleep=True) if track.loadType in {"track", "playlist", "search"}: await self._add_to_query_cache(query, path) match track.loadType: case "track": await self._add_to_track_cache(track.data, path_obj) case "playlist": for track in track.data.tracks: await self._add_to_track_cache(track, path_obj) case "search": for track in track.data: await self._add_to_track_cache(track, path_obj) async def _process_deleted(self, path: str, path_obj: pathlib.Path) -> None: if self.__shutdown: return query = await Query.from_string(path_obj) await self._remove_from_query_cache(query, path) await self._remove_from_track_cache(path_obj)
[docs] async def update(self) -> None: """Update the local track cache.""" if self.__shutdown: return await self.__pylav.wait_until_ready() chunk_size = min(POSTGRES_CONNECTIONS, 50) LOGGER.debug("Updating cache") start = utcnow() chunk = [] for entry in self.__root_folder.rglob("*"): if (not entry.is_dir()) and entry.suffix.lower() not in ALL_EXTENSIONS: continue chunk.append(entry) if len(chunk) == chunk_size: await asyncio.gather(*[self._process_added(f"{entry}", entry) for entry in chunk]) chunk = [] if chunk: await asyncio.gather(*[self._process_added(f"{entry}", entry) for entry in chunk]) LOGGER.debug("Finished updating cache in %s", utcnow() - start)