Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Include a Snapserver in the snapcast provider #1150

Merged
merged 1 commit into from
Mar 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 12 additions & 11 deletions music_assistant/common/helpers/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,26 +151,27 @@ def _get_ip():
return await asyncio.to_thread(_get_ip)


async def select_free_port(range_start: int, range_end: int) -> int:
"""Automatically find available port within range."""
async def is_port_in_use(port: int) -> bool:
"""Check if port is in use."""

def is_port_in_use(port: int) -> bool:
"""Check if port is in use."""
def _is_port_in_use() -> bool:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as _sock:
try:
_sock.bind(("0.0.0.0", port))
except OSError:
return True
return False

def _select_free_port():
for port in range(range_start, range_end):
if not is_port_in_use(port):
return port
msg = "No free port available"
raise OSError(msg)
return await asyncio.to_thread(_is_port_in_use)


return await asyncio.to_thread(_select_free_port)
async def select_free_port(range_start: int, range_end: int) -> int:
"""Automatically find available port within range."""
for port in range(range_start, range_end):
if not await is_port_in_use(port):
return port
msg = "No free port available"
raise OSError(msg)


async def get_ip_from_host(dns_name: str) -> str | None:
Expand Down
15 changes: 6 additions & 9 deletions music_assistant/server/helpers/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,11 +133,12 @@ async def close(self) -> int:
if self.returncode is not None:
return self.returncode
# make sure the process is cleaned up
self._proc.terminate()
try:
async with asyncio.timeout(10):
await self.communicate()
except (TimeoutError, asyncio.CancelledError):
self._proc.terminate()
self._proc.kill()
return await self.wait()

async def wait(self) -> int:
Expand All @@ -151,14 +152,10 @@ async def communicate(self, input_data: bytes | None = None) -> tuple[bytes, byt
stdout, stderr = await self._proc.communicate(input_data)
return (stdout, stderr)

async def read_stderr(self, n: int = -1) -> bytes:
"""Read up to n bytes from the stderr stream.

If n is positive, this function try to read n bytes,
and may return less or equal bytes than requested, but at least one byte.
If EOF was received before any byte is read, this function returns empty byte object.
"""
return await self._proc.stderr.read(n)
async def read_stderr(self) -> AsyncGenerator[bytes, None]:
"""Read lines from the stderr stream."""
async for line in self._proc.stderr:
yield line


async def check_output(shell_cmd: str) -> tuple[int, bytes]:
Expand Down
94 changes: 73 additions & 21 deletions music_assistant/server/providers/snapcast/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from music_assistant.common.models.media_items import AudioFormat
from music_assistant.common.models.player import DeviceInfo, Player
from music_assistant.server.helpers.audio import get_media_stream
from music_assistant.server.helpers.process import AsyncProcess, check_output
from music_assistant.server.models.player_provider import PlayerProvider
from music_assistant.server.providers.ugp import UGP_PREFIX

Expand All @@ -44,14 +45,16 @@
from music_assistant.server import MusicAssistant
from music_assistant.server.models import ProviderInstanceType

CONF_SNAPCAST_SERVER_HOST = "snapcast_server_host"
CONF_SNAPCAST_SERVER_CONTROL_PORT = "snapcast_server_control_port"
CONF_SERVER_HOST = "snapcast_server_host"
CONF_SERVER_CONTROL_PORT = "snapcast_server_control_port"
CONF_USE_EXTERNAL_SERVER = "snapcast_use_external_server"

SNAP_STREAM_STATUS_MAP = {
"idle": PlayerState.IDLE,
"playing": PlayerState.PLAYING,
"unknown": PlayerState.IDLE,
}
DEFAULT_SNAPSERVER_PORT = 1705


async def setup(
Expand All @@ -64,10 +67,10 @@ async def setup(


async def get_config_entries(
mass: MusicAssistant,
instance_id: str | None = None,
action: str | None = None,
values: dict[str, ConfigValueType] | None = None,
mass: MusicAssistant, # noqa: ARG001
instance_id: str | None = None, # noqa: ARG001
action: str | None = None, # noqa: ARG001
values: dict[str, ConfigValueType] | None = None, # noqa: ARG001
) -> tuple[ConfigEntry, ...]:
"""
Return Config entries to setup this provider.
Expand All @@ -76,21 +79,37 @@ async def get_config_entries(
action: [optional] action key called from config entries UI.
values: the (intermediate) raw values for config entries sent with the action.
"""
# ruff: noqa: ARG001
returncode, output = await check_output("snapserver -v")
snapserver_present = returncode == 0 and "snapserver v0.27.0" in output.decode()
return (
ConfigEntry(
key=CONF_SNAPCAST_SERVER_HOST,
key=CONF_USE_EXTERNAL_SERVER,
type=ConfigEntryType.BOOLEAN,
default_value=not snapserver_present,
label="Use existing Snapserver",
required=False,
description="Music Assistant by default already includes a Snapserver. \n\n"
"Checking this option allows you to connect to your own/external existing Snapserver "
"and not use the builtin one provided by Music Assistant.",
advanced=snapserver_present,
),
ConfigEntry(
key=CONF_SERVER_HOST,
type=ConfigEntryType.STRING,
default_value="127.0.0.1",
label="Snapcast server ip",
required=True,
required=False,
depends_on=CONF_USE_EXTERNAL_SERVER,
advanced=snapserver_present,
),
ConfigEntry(
key=CONF_SNAPCAST_SERVER_CONTROL_PORT,
key=CONF_SERVER_CONTROL_PORT,
type=ConfigEntryType.INTEGER,
default_value="1705",
default_value=DEFAULT_SNAPSERVER_PORT,
label="Snapcast control port",
required=True,
required=False,
depends_on=CONF_USE_EXTERNAL_SERVER,
advanced=snapserver_present,
),
)

Expand All @@ -99,9 +118,12 @@ class SnapCastProvider(PlayerProvider):
"""Player provider for Snapcast based players."""

_snapserver: Snapserver
snapcast_server_host: str
snapcast_server_control_port: int
_snapcast_server_host: str
_snapcast_server_control_port: int
_stream_tasks: dict[str, asyncio.Task]
_use_builtin_server: bool
_snapserver_runner: asyncio.Task | None
_snapserver_started = asyncio.Event | None

@property
def supported_features(self) -> tuple[ProviderFeature, ...]:
Expand All @@ -110,20 +132,29 @@ def supported_features(self) -> tuple[ProviderFeature, ...]:

async def handle_async_init(self) -> None:
"""Handle async initialization of the provider."""
self.snapcast_server_host = self.config.get_value(CONF_SNAPCAST_SERVER_HOST)
self.snapcast_server_control_port = self.config.get_value(CONF_SNAPCAST_SERVER_CONTROL_PORT)
self._snapcast_server_host = self.config.get_value(CONF_SERVER_HOST)
self._snapcast_server_control_port = self.config.get_value(CONF_SERVER_CONTROL_PORT)
self._use_builtin_server = not self.config.get_value(CONF_USE_EXTERNAL_SERVER)
self._stream_tasks = {}
if self._use_builtin_server:
# start our own builtin snapserver
self._snapserver_started = asyncio.Event()
self._snapserver_runner = asyncio.create_task(self._builtin_server_runner())
await asyncio.wait_for(self._snapserver_started.wait(), 10)
else:
self._snapserver_runner = None
self._snapserver_started = None
try:
self._snapserver = await create_server(
self.mass.loop,
self.snapcast_server_host,
port=self.snapcast_server_control_port,
self._snapcast_server_host,
port=self._snapcast_server_control_port,
reconnect=True,
)
self._snapserver.set_on_update_callback(self._handle_update)
self.logger.info(
f"Started Snapserver connection on:"
f"{self.snapcast_server_host}:{self.snapcast_server_control_port}"
"Started connection to Snapserver %s",
f"{self._snapcast_server_host}:{self._snapcast_server_control_port}",
)
except OSError as err:
msg = "Unable to start the Snapserver connection ?"
Expand All @@ -139,6 +170,10 @@ async def unload(self) -> None:
for client in self._snapserver.clients:
await self.cmd_stop(client.identifier)
await self._snapserver.stop()
self._snapserver_started.clear()
if self._snapserver_runner and not self._snapserver_runner.done():
self._snapserver_runner.cancel()
await asyncio.sleep(2) # prevent race conditions when reloading

def _handle_update(self) -> None:
"""Process Snapcast init Player/Group and set callback ."""
Expand Down Expand Up @@ -287,7 +322,7 @@ async def play_media(self, player_id: str, queue_item: QueueItem) -> None:
)

async def _streamer() -> None:
host = self.snapcast_server_host
host = self._snapcast_server_host
_, writer = await asyncio.open_connection(host, port)
self.logger.debug("Opened connection to %s:%s", host, port)
player.current_item_id = f"{queue_item.queue_id}.{queue_item.queue_item_id}"
Expand Down Expand Up @@ -374,3 +409,20 @@ def _set_childs_state(self, player_id: str, state: PlayerState) -> None:
player = self.mass.players.get(child_player_id)
player.state = state
self.mass.players.update(child_player_id)

async def _builtin_server_runner(self) -> None:
"""Start running the builtin snapserver."""
if self._snapserver_started.is_set():
raise RuntimeError("Snapserver is already started!")
logger = self.logger.getChild("snapserver")
logger.info("Starting builtin Snapserver...")
async with AsyncProcess(
["snapserver"], enable_stdin=False, enable_stdout=True, enable_stderr=False
) as snapserver_proc:
# keep reading from stderr until exit
async for data in snapserver_proc.iter_any():
data = data.decode().strip() # noqa: PLW2901
for line in data.split("\n"):
logger.debug(line)
if "Name now registered and active" in line:
self._snapserver_started.set()