From 39835c2852f0289d172ca61e37c024ea53509524 Mon Sep 17 00:00:00 2001 From: alex Date: Mon, 29 Jan 2024 05:14:35 +0100 Subject: [PATCH] add extensions as runtime-checkable protocols, make BaseChannelLayer abstract --- channels/layers.py | 76 +++++++++++++++++++++++++++++----------------- 1 file changed, 48 insertions(+), 28 deletions(-) diff --git a/channels/layers.py b/channels/layers.py index e64520daa..20b91c439 100644 --- a/channels/layers.py +++ b/channels/layers.py @@ -6,8 +6,18 @@ import re import string import time +from abc import ABC, abstractmethod from copy import deepcopy -from typing import Dict, Iterable, List, Optional, Tuple +from typing import ( + Dict, + Iterable, + List, + NoReturn, + Optional, + Protocol, + Tuple, + runtime_checkable, +) from django.conf import settings from django.core.signals import setting_changed @@ -97,7 +107,39 @@ def set(self, key: str, layer: BaseChannelLayer): return old -class BaseChannelLayer: +@runtime_checkable +class WithFlushExtension(Protocol): + async def flush(self) -> NoReturn: + """ + Clears messages and if available groups + """ + + async def close(self) -> NoReturn: + """ + Close connection to the layer. Called before stopping layer. + Unusable after. + """ + + +@runtime_checkable +class WithGroupsExtension(Protocol): + async def group_add(self, group: str, channel: str): + """ + Adds the channel name to a group. + """ + + async def group_discard(self, group: str, channel: str) -> NoReturn: + """ + Removes the channel name from a group when it exists. + """ + + async def group_send(self, group: str, message: dict) -> NoReturn: + """ + Sends message to group + """ + + +class BaseChannelLayer(ABC): """ Base channel layer class that others can inherit from, with useful common functionality. @@ -199,51 +241,29 @@ def non_local_name(self, name: str) -> str: else: return name + @abstractmethod async def send(self, channel: str, message: dict): """ Send a message onto a (general or specific) channel. """ - raise NotImplementedError() + @abstractmethod async def receive(self, channel: str) -> dict: """ Receive the first message that arrives on the channel. If more than one coroutine waits on the same channel, a random one of the waiting coroutines will get the result. """ - raise NotImplementedError() + @abstractmethod async def new_channel(self, prefix: str = "specific.") -> str: """ Returns a new channel name that can be used by something in our process as a specific channel. """ - raise NotImplementedError() - - # Flush extension - - async def flush(self): - raise NotImplementedError() - - async def close(self): - raise NotImplementedError() - - # Groups extension - - async def group_add(self, group: str, channel: str): - """ - Adds the channel name to a group. - """ - raise NotImplementedError() - - async def group_discard(self, group: str, channel: str): - raise NotImplementedError() - - async def group_send(self, group: str, message: dict): - raise NotImplementedError() -class InMemoryChannelLayer(BaseChannelLayer): +class InMemoryChannelLayer(WithFlushExtension, WithGroupsExtension, BaseChannelLayer): """ In-memory channel layer implementation """