from __future__ import annotations
import asyncio
import operator
import os
from functools import partial
from typing import TYPE_CHECKING
import aiohttp
import asyncstdlib
from pylav.compat import json
from pylav.constants.builtin_nodes import BUNDLED_NODES_IDS_HOST_MAPPING, PYLAV_BUNDLED_NODES_SETTINGS
from pylav.constants.config import EXTERNAL_UNMANAGED_NAME, JAVA_EXECUTABLE
from pylav.constants.coordinates import DEFAULT_REGIONS, REGION_TO_COUNTRY_COORDINATE_MAPPING
from pylav.events.node import NodeConnectedEvent, NodeDisconnectedEvent
from pylav.exceptions.client import PyLavNotInitializedException
from pylav.helpers.misc import ExponentialBackoffWithReset
from pylav.logging import getLogger
from pylav.nodes.node import Node
from pylav.nodes.utils import sort_key_nodes
from pylav.players.player import Player
from pylav.storage.models.node.mocked import NodeMock
from pylav.utils.location import get_closest_region_name_and_coordinate
if TYPE_CHECKING:
    # Imported only for type annotations; avoids a runtime import cycle with the client module.
    from pylav.core.client import Client

# Module-level logger for the node manager.
LOGGER = getLogger("PyLav.NodeManager")
class NodeManager:
    """Manages nodes and their connections to the client."""

    __slots__ = (
        "_client",
        "_session",
        "_player_queue",
        "_unmanaged_external_host",
        "_unmanaged_external_password",
        "_unmanaged_external_port",
        "_unmanaged_external_ssl",
        "_nodes",
        "_adding_nodes",
        "_player_migrate_task",
    )

    # Identifier given to the external node configured through the constructor arguments.
    # That node is added as temporary and is never deleted from the database in ``remove_node``.
    _EXTERNAL_UNMANAGED_NODE_ID = 31415

    def __init__(
        self,
        client: Client,
        external_host: str | None = None,
        external_password: str | None = None,
        external_port: int | None = None,
        external_ssl: bool = False,
    ):
        """Create the node manager.

        Parameters
        ----------
        client: :class:`Client`
            The PyLav client that owns this manager.
        external_host: Optional[:class:`str`]
            Host of an external, unmanaged node added by ``connect_to_all_nodes``.
        external_password: Optional[:class:`str`]
            Password of the external node; the node is only added when both host and password are set.
        external_port: Optional[:class:`int`]
            Port of the external node; when omitted, 443 is used if ``external_ssl`` is set, otherwise 80.
        external_ssl: :class:`bool`
            Whether the external node uses SSL.
        """
        self._client = client
        self._session = aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=120), json_serialize=json.dumps)
        self._player_queue = set()
        self._unmanaged_external_host = external_host
        self._unmanaged_external_password = external_password
        self._unmanaged_external_port = external_port
        self._unmanaged_external_ssl = external_ssl
        self._nodes = []
        self._adding_nodes = asyncio.Event()
        self._player_migrate_task = None

    def __iter__(self):
        yield from self._nodes

    @property
    def session(self) -> aiohttp.ClientSession:
        """Returns the aiohttp session used by the client"""
        return self._session

    @property
    def client(self) -> Client:
        """Returns the client"""
        return self._client

    @property
    def nodes(self) -> list[Node]:
        """Returns a list of all nodes"""
        return self._nodes

    @property
    def available_nodes(self) -> list[Node]:
        """Returns a list of available nodes"""
        return list(filter(operator.attrgetter("available"), self.nodes))

    @property
    def managed_nodes(self) -> list[Node]:
        """Returns a list of nodes that are managed by the client"""
        return list(filter(operator.attrgetter("managed"), self.nodes))

    @property
    def search_only_nodes(self) -> list[Node]:
        """Returns a list of available nodes that are search only"""
        # Both flags must be checked individually: a multi-attribute ``attrgetter``
        # yields a tuple, which is always truthy and would let every node through.
        return [n for n in self.nodes if n.available and n.search_only]

    @property
    def player_queue(self) -> list[Player]:
        """Returns a list of players that are queued to be played"""
        return list(self._player_queue)

    @player_queue.setter
    def player_queue(self, players: list[Player]) -> None:
        """Sets the player queue"""
        self._player_queue = set(players)

    @player_queue.deleter
    def player_queue(self):
        """Clears the player queue"""
        self._player_queue.clear()

    async def add_node(
        self,
        *,
        host: str,
        port: int,
        password: str,
        unique_identifier: int,
        name: str,
        resume_timeout: int = 60,
        reconnect_attempts: int = -1,
        ssl: bool = False,
        search_only: bool = False,
        disabled_sources: list[str] | None = None,
        managed: bool = False,
        yaml: dict | None = None,
        extras: dict | None = None,
        temporary: bool = False,
    ) -> Node:
        """
        Adds a node to PyLav's node manager.

        Parameters
        ----------
        host: :class:`str`
            The address of the Lavalink node.
        port: :class:`int`
            The port to use for websocket and REST connections.
        password: :class:`str`
            The password used for authentication.
        unique_identifier: :class:`int`
            A unique identifier for the node.
        name: :class:`str`
            An identifier for the node that will show in logs.
        resume_timeout: Optional[:class:`int`]
            How long the node should wait for a connection while disconnected before clearing all players.
            Defaults to `60`.
        reconnect_attempts: Optional[:class:`int`]
            The amount of times connection with the node will be reattempted before giving up.
            Set to `-1` for infinite. Defaults to `-1`.
        ssl: Optional[:class:`bool`]
            Whether to use a ssl connection. Defaults to `False`.
        search_only: :class:`bool`
            Whether the node is search only. Defaults to `False`.
        disabled_sources: Optional[:class:`list`[:class:`str`]]
            A list of sources to disable. Defaults to `None`.
        managed: Optional[:class:`bool`]
            Whether the node is managed by the client. Defaults to `False`.
        yaml: Optional[:class:`dict`]
            A dictionary of node settings. Defaults to `None`.
            For temporary nodes the caller's dictionary is not modified.
        extras: Optional[:class:`dict`]
            A dictionary of extra settings. Defaults to `{}`.
        temporary: :class:`bool`
            Whether the node is temporary. Defaults to `False`.
            Temporary nodes are not added to the db.

        Returns
        -------
        :class:`Node`
            The node that was added.
        """
        node = Node(
            manager=self,
            host=host,
            port=port,
            password=password,
            resume_timeout=resume_timeout,
            name=name,
            reconnect_attempts=reconnect_attempts,
            ssl=ssl,
            search_only=search_only,
            unique_identifier=unique_identifier,
            disabled_sources=disabled_sources,
            managed=managed,
            extras=extras or {},
            temporary=temporary,
        )
        self._nodes.append(node)
        # noinspection PyProtectedMember
        node._logger.info("Successfully added to Node Manager")
        # noinspection PyProtectedMember
        node._logger.verbose("Successfully added to Node Manager -- %r", node)
        if temporary:
            # Build the YAML from copies so the caller's dict (and its nested dicts) is not
            # mutated, and so a partial mapping lacking the "server" sections still works.
            yaml = dict(yaml or {})
            server = dict(yaml.get("server") or {})
            server["address"] = host
            server["port"] = port
            lavalink = dict(yaml.get("lavalink") or {})
            lavalink_server = dict(lavalink.get("server") or {})
            lavalink_server["password"] = password
            lavalink["server"] = lavalink_server
            yaml["server"] = server
            yaml["lavalink"] = lavalink
            data = {
                "name": name,
                "ssl": ssl,
                "resume_timeout": resume_timeout,
                "reconnect_attempts": reconnect_attempts,
                "search_only": search_only,
                "managed": managed,
                "extras": extras or {},
                "yaml": yaml,
                "disabled_sources": disabled_sources,
            }
            # Temporary nodes get an in-memory config instead of a database row.
            node._config = NodeMock(id=unique_identifier, data=data)
        else:
            node._config = await self.client.node_db_manager.update_node(
                host=host,
                port=port,
                password=password,
                resume_timeout=resume_timeout,
                name=name,
                reconnect_attempts=reconnect_attempts,
                ssl=ssl,
                search_only=search_only,
                unique_identifier=unique_identifier,
                disabled_sources=disabled_sources,
                managed=managed,
                yaml=yaml,
                extras=extras or {},
            )
        return node

    async def remove_node(self, node: Node) -> None:
        """
        Removes a node.

        The node is closed and dropped from the manager. Its database entry is deleted unless it is
        managed, one of the bundled nodes, or the external node configured via the constructor.

        Parameters
        ----------
        node: :class:`Node`
            The node to remove from the list.

        Raises
        ------
        ValueError
            If the node is not part of this manager.
        """
        await node.close()
        self.nodes.remove(node)
        # noinspection PyProtectedMember
        node._logger.info("Successfully removed Node")
        # noinspection PyProtectedMember
        node._logger.verbose("Successfully removed Node -- %r", node)
        if (
            node.identifier
            and not node.managed
            and node.identifier not in BUNDLED_NODES_IDS_HOST_MAPPING
            and node.identifier != self._EXTERNAL_UNMANAGED_NODE_ID
        ):
            await self.client.node_db_manager.delete(node.identifier)
            # noinspection PyProtectedMember
            node._logger.debug("Successfully deleted Node from database")

    def get_region(self, endpoint: str | None) -> str | None:
        """
        Returns a region from a Discord voice server address.

        Only regions served by at least one available node are considered.

        Parameters
        ----------
        endpoint: :class:`str`
            The address of the Discord voice server.

        Returns
        -------
        Optional[:class:`str`]
            The first matching region from ``DEFAULT_REGIONS``, or ``None``.
        """
        if not endpoint:
            return None
        endpoint = endpoint.replace("vip-", "")
        # Collect served regions once instead of rescanning the nodes for every region key.
        served_regions = {n.region for n in self.available_nodes}
        return next(
            (key for key in DEFAULT_REGIONS if key in served_regions and endpoint.startswith(key)),
            None,
        )

    def get_closest_node(self, region: str) -> Node:
        """
        Returns the closest node to a given region.

        Parameters
        ----------
        region: :class:`str`
            The region to use.

        Returns
        -------
        :class:`Node`

        Raises
        ------
        ValueError
            If there are no available nodes.
        """
        return min(self.available_nodes, key=lambda n: n.region_distance(region))

    async def find_best_node(
        self,
        region: str | None = None,
        not_region: str | None = None,
        feature: str | None = None,
        already_attempted_regions: set[str] | None = None,
        coordinates: tuple[float, float] | None = None,
        wait: bool = False,
        attempt: int = 0,
        backoff: ExponentialBackoffWithReset | None = None,
    ) -> Node | None:
        """Finds the best (least used) node in the given region, if applicable.

        Parameters
        ----------
        region: :class:`str`
            The region to use.
        not_region: :class:`str`
            The region to exclude.
        feature: :class:`str`
            The feature required.
        already_attempted_regions: :class:`set`[:class:`str`]
            A set of regions that have already been attempted.
        coordinates: :class:`tuple`[:class:`float`, :class:`float`]
            The coordinates to use. Derived from ``region`` when omitted, falling back to ``(0, 0)``.
        wait: :class:`bool`
            Whether to keep retrying (sleeping between attempts) until a node becomes available.
        attempt: :class:`int`
            The current attempt number. Retained for backwards compatibility; not used for selection.
        backoff: :class:`ExponentialBackoffWithReset`
            The backoff used to compute the delay between attempts.
            The first attempt without a backoff waits 1 second before retrying.

        Returns
        -------
        Optional[:class:`Node`]
        """
        if backoff is None:
            backoff = ExponentialBackoffWithReset()
            delay = 1
        else:
            delay = backoff.delay()
        already_attempted_regions = already_attempted_regions or set()
        if coordinates is None:
            if region and region in REGION_TO_COUNTRY_COORDINATE_MAPPING:
                coordinates = REGION_TO_COUNTRY_COORDINATE_MAPPING[region]
            else:
                coordinates = (0, 0)
        # Retry in a loop rather than recursively so that a long wait cannot exhaust the recursion limit.
        while True:
            node = await self._find_best_node_once(region, not_region, feature, already_attempted_regions, coordinates)
            if node is not None or not wait:
                return node
            await asyncio.sleep(delay)
            delay = backoff.delay()

    async def _find_best_node_once(self, region, not_region, feature, already_attempted_regions, coordinates):
        """Run a single node-selection pass; returns the best matching node or ``None``."""
        if feature:
            nodes = [n for n in self.available_nodes if n.has_capability(feature)]
        else:
            nodes = self.available_nodes
        if region and not_region:
            nodes = await self._get_nodes_by_region_with_exclusion(
                already_attempted_regions, coordinates, nodes, not_region, region
            )
        elif region:
            nodes = await self._get_nodes_by_region_only(already_attempted_regions, coordinates, nodes, region)
        else:
            nodes = [n for n in nodes if n.region != not_region and n.region not in already_attempted_regions]
        if not nodes:
            nodes = await self._get_fall_back_nodes(already_attempted_regions, feature, nodes)
        if not nodes:
            return None
        return await asyncstdlib.min(nodes, key=partial(sort_key_nodes, region=region), default=None)

    async def _get_fall_back_nodes(self, already_attempted_regions, feature, nodes):
        """Return the fallback candidates when the region-based selection found nothing.

        With ``feature`` set: available nodes with that capability outside the already attempted regions.
        Otherwise: all available nodes. The ``nodes`` argument is ignored.
        """
        if feature:
            return [
                n
                for n in self.available_nodes
                if n.has_capability(feature) and n.region not in already_attempted_regions
            ]
        return self.available_nodes

    async def _get_nodes_by_region_only(self, already_attempted_regions, coordinates, nodes, region):
        """Filter ``nodes`` to ``region`` or the closest not-yet-attempted region to ``coordinates``."""
        available_regions = {n.region for n in self.available_nodes if n.region not in already_attempted_regions}
        closest_region, __ = await get_closest_region_name_and_coordinate(*coordinates, region_pool=available_regions)
        return [
            n for n in nodes if (n.region in [region, closest_region]) and n.region not in already_attempted_regions
        ]

    async def _get_nodes_by_region_with_exclusion(
        self, already_attempted_regions, coordinates, nodes, not_region, region
    ):
        """Like ``_get_nodes_by_region_only`` but also excludes nodes in ``not_region``."""
        available_regions = {n.region for n in self.available_nodes if n.region not in already_attempted_regions}
        closest_region, __ = await get_closest_region_name_and_coordinate(*coordinates, region_pool=available_regions)
        return [
            n
            for n in nodes
            if (n.region in [region, closest_region])
            and n.region != not_region
            and n.region not in already_attempted_regions
        ]

    def get_node_by_id(self, unique_identifier: int) -> Node | None:
        """
        Returns a node by its unique identifier.

        Parameters
        ----------
        unique_identifier: :class:`int`
            The unique identifier of the node.

        Returns
        -------
        Optional[:class:`Node`]
        """
        return next((n for n in self.nodes if n.identifier == unique_identifier), None)

    async def node_connect(self, node: Node) -> None:
        """
        Called when a connection to a Lavalink node has been established.

        Starts a background task moving queued players to the node and dispatches a
        :class:`NodeConnectedEvent`.

        Parameters
        ----------
        node: :class:`Node`
            The node that has just connected.
        """
        # noinspection PyProtectedMember
        node._logger.debug("Successfully established connection")
        del node.down_votes
        self._player_migrate_task = asyncio.create_task(self._player_change_node_task(node))
        self.client.dispatch_event(NodeConnectedEvent(node))

    async def _player_change_node_task(self, node: Node) -> None:
        """Move queued players to ``node``; with connect-back enabled, also bring back its original players."""
        try:
            # ``player_queue`` returns a copy, so players queued while this runs are not touched here.
            queued = self.player_queue
            for player in queued:
                await player.change_node(node, forced=True)
                # noinspection PyProtectedMember
                node._logger.debug("Successfully moved %s", player.guild.id)
            # noinspection PyProtectedMember
            if self.client._connect_back:
                # Iterate a snapshot in case moving a player changes the node's collection.
                # noinspection PyProtectedMember
                for player in tuple(node._original_players):
                    await player.change_node(node, forced=True)
                    player._original_node = None
            # Only drop the players that were moved; anything queued meanwhile stays queued.
            self._player_queue.difference_update(queued)
        finally:
            # Don't clobber the reference to a newer migration task started by a later connect.
            if self._player_migrate_task is asyncio.current_task():
                self._player_migrate_task = None

    async def node_disconnect(self, node: Node, code: int, reason: str) -> None:
        """
        Called when a node is disconnected from Lavalink.

        The node's players are moved to the best available node. When no node is available,
        they are queued until a node connects.

        Parameters
        ----------
        node: :class:`Node`
            The node that has just disconnected.
        code: :class:`int`
            The code for why the node was disconnected.
        reason: :class:`str`
            The reason why the node was disconnected.
        """
        if self.client.is_shutting_down:
            return
        # noinspection PyProtectedMember
        node._logger.warning("Disconnected with code %s and reason %s", code, reason)
        # noinspection PyProtectedMember
        node._logger.verbose(
            "Disconnected with code %s and reason %s -- %r",
            code,
            reason,
            node,
        )
        self.client.dispatch_event(NodeDisconnectedEvent(node, code, reason))
        best_node = await self.find_best_node(region=node.region)
        if not best_node or not best_node.available:
            self.player_queue = self.player_queue + node.players
            LOGGER.error("Unable to move players, no available nodes! Waiting for a node to become available")
            return
        # Snapshot: moving a player away may modify ``node.players`` while we iterate.
        for player in tuple(node.players):
            await player.change_node(best_node, forced=True)
            # noinspection PyProtectedMember
            if self.client._connect_back:
                player._original_node = node

    async def close(self) -> None:
        """Cancels any pending player migration, closes the session and disconnects all nodes."""
        if self._player_migrate_task is not None:
            self._player_migrate_task.cancel()
        await self.session.close()
        for node in tuple(self.nodes):
            await node.close()

    async def connect_to_all_nodes(self) -> bool:
        """Connects to all nodes.

        Adds every unmanaged node stored in the database plus the external node configured through
        the constructor, then waits for each of them to become ready.

        Returns
        -------
        :class:`bool`
            ``True`` once at least one node is available, or when there are no nodes to connect to
            but the managed node is enabled.

        Raises
        ------
        PyLavNotInitializedException
            If there are no nodes to connect to and the managed node is disabled, or if none of
            the nodes became available.
        """
        nodes_list: list[Node] = []
        for db_node in await self.client.node_db_manager.get_all_unmanaged_nodes():
            await self._process_single_unmanaged_node_connection(db_node, nodes_list)
        await self._process_envvar_node(nodes_list)
        # noinspection PyProtectedMember
        config_data = self.client._lib_config_manager.get_config()
        all_data = await config_data.fetch_all()
        # Keep the stored Java path in sync with the configured executable when that executable exists.
        if all_data["java_path"] != JAVA_EXECUTABLE and os.path.exists(JAVA_EXECUTABLE):
            await config_data.update_java_path(JAVA_EXECUTABLE)
        tasks = [asyncio.create_task(n.wait_until_ready()) for n in nodes_list]
        if not tasks:
            if await self.client.managed_node_is_enabled():
                self._adding_nodes.set()
                return True
            LOGGER.warning("No nodes found, please add some nodes")
            raise PyLavNotInitializedException("Failed to connect to any nodes")
        # ALL_COMPLETED without a timeout leaves nothing pending; surface any failure from a finished task.
        done, _ = await asyncio.wait(tasks, return_when=asyncio.ALL_COMPLETED)
        for task in done:
            task.result()
        if not any(node.available for node in nodes_list):
            raise PyLavNotInitializedException("No nodes are available")
        self._adding_nodes.set()
        return True

    @staticmethod
    def _host_already_listed(nodes_list: list[Node], host: str) -> bool:
        """Return whether any node in ``nodes_list`` uses ``host``."""
        return any(n.host == host for n in nodes_list)

    async def _process_bundled_node_lava_link(self, nodes_list):
        """Add the bundled lava.link node unless a node with that host is already listed."""
        if not self._host_already_listed(nodes_list, "lava.link"):
            nodes_list.append(
                await self.add_node(
                    password=f"PyLav/{self.client.lib_version}",
                    **PYLAV_BUNDLED_NODES_SETTINGS["lava.link"],
                )
            )
        else:
            LOGGER.debug(
                "%s already added to connection pool - skipping duplicated connection",
                PYLAV_BUNDLED_NODES_SETTINGS["lava.link"]["name"],
            )

    async def _process_bundled_node_ny(self, nodes_list):
        """Add the bundled New York node unless its host or identifier is already present."""
        base_settings = PYLAV_BUNDLED_NODES_SETTINGS["ll-us-ny.draper.wtf"]
        if not self._host_already_listed(nodes_list, "ll-us-ny.draper.wtf") and not self.get_node_by_id(
            base_settings["unique_identifier"]
        ):
            nodes_list.append(await self.add_node(**base_settings))
        else:
            LOGGER.debug(
                "%s already added to connection pool - skipping duplicated connection",
                base_settings["name"],
            )

    async def _process_bundled_node_london(self, nodes_list):
        """Add the bundled London node unless its host or identifier is already present."""
        shared_settings = PYLAV_BUNDLED_NODES_SETTINGS["ll-gb.draper.wtf"]
        if not self._host_already_listed(nodes_list, "ll-gb.draper.wtf") and not self.get_node_by_id(
            shared_settings["unique_identifier"]
        ):
            # Copy rather than writing into the shared module-level settings mapping.
            base_settings = {**shared_settings, "host": "ll-gb.draper.wtf"}
            nodes_list.append(await self.add_node(**base_settings))
        else:
            LOGGER.debug(
                "%s already added to connection pool - skipping duplicated connection",
                shared_settings["name"],
            )

    async def _process_envvar_node(self, nodes_list):
        """Add the external node configured via the constructor, if host and password were given."""
        if not (self._unmanaged_external_host and self._unmanaged_external_password):
            return
        if self._host_already_listed(nodes_list, self._unmanaged_external_host):
            LOGGER.warning(
                "%s already added to connection pool - skipping duplicated connection - (%s:%s)",
                EXTERNAL_UNMANAGED_NAME,
                self._unmanaged_external_host,
                self._unmanaged_external_port,
            )
            return
        if self._unmanaged_external_host in PYLAV_BUNDLED_NODES_SETTINGS:
            base_settings = PYLAV_BUNDLED_NODES_SETTINGS[self._unmanaged_external_host]
        else:
            base_settings = {
                "port": self._unmanaged_external_port or (443 if self._unmanaged_external_ssl else 80),
                "ssl": self._unmanaged_external_ssl,
                "password": self._unmanaged_external_password,
                "resume_timeout": 600,
                "reconnect_attempts": -1,
                "search_only": False,
                "managed": False,
                "disabled_sources": [],
                "host": self._unmanaged_external_host,
                "unique_identifier": self._EXTERNAL_UNMANAGED_NODE_ID,
                "name": EXTERNAL_UNMANAGED_NAME,
                "temporary": True,
            }
        nodes_list.append(await self.add_node(**base_settings))

    async def _process_single_unmanaged_node_connection(self, node, nodes_list):
        """Add one database-stored node to ``nodes_list``; invalid entries are logged and skipped."""
        if node.id == self.client.bot.user.id:
            LOGGER.debug("Skipping node %s as it is the managed node", node.id)
            return
        node_data = await node.fetch_all()
        try:
            # NOTE(review): ``node`` is a database entry while ``nodes_list`` holds objects returned
            # by ``add_node``; confirm this membership test can actually match.
            if node in nodes_list:
                LOGGER.warning(
                    "%s Node already added to connection pool - skipping duplicated connection - (%s:%s)",
                    node_data["name"],
                    node_data["yaml"]["server"]["address"],
                    node_data["yaml"]["server"]["port"],
                )
                return
            if node_data["yaml"]["server"]["address"] in PYLAV_BUNDLED_NODES_SETTINGS:
                connection_arguments = PYLAV_BUNDLED_NODES_SETTINGS[node_data["yaml"]["server"]["address"]]
            else:
                connection_arguments = await node.get_connection_args()
            nodes_list.append(await self.add_node(**connection_arguments))
        except (ValueError, KeyError) as exc:
            LOGGER.warning("Invalid node, skipping ... id: %s - Original error: %s", node.id, exc)

    async def wait_until_ready(self, timeout: float | None = None):
        """Wait until all nodes are ready.

        Raises
        ------
        asyncio.TimeoutError
            If ``timeout`` seconds elapse first.
        """
        await asyncio.wait_for(self._adding_nodes.wait(), timeout=timeout)