Skip to content

Commit

Permalink
Add warning for the length of the group name (#2122)
Browse files Browse the repository at this point in the history
  • Loading branch information
IronJam11 authored Jan 28, 2025
1 parent b502c73 commit a144b4b
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 30 deletions.
54 changes: 25 additions & 29 deletions channels/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,35 +144,31 @@ def match_type_and_length(self, name):
invalid_name_error = (
"{} name must be a valid unicode string "
+ "with length < {} ".format(MAX_NAME_LENGTH)
+ "containing only ASCII alphanumerics, hyphens, underscores, or periods, "
+ "not {}"
+ "containing only ASCII alphanumerics, hyphens, underscores, or periods."
)

def valid_channel_name(self, name, receive=False):
if self.match_type_and_length(name):
if bool(self.channel_name_regex.match(name)):
# Check cases for special channels
if "!" in name and not name.endswith("!") and receive:
raise TypeError(
"Specific channel names in receive() must end at the !"
)
return True
raise TypeError(self.invalid_name_error.format("Channel", name))

def valid_group_name(self, name):
if self.match_type_and_length(name):
if bool(self.group_name_regex.match(name)):
return True
raise TypeError(self.invalid_name_error.format("Group", name))
def require_valid_channel_name(self, name, receive=False):
if not self.match_type_and_length(name):
raise TypeError(self.invalid_name_error.format("Channel"))
if not bool(self.channel_name_regex.match(name)):
raise TypeError(self.invalid_name_error.format("Channel"))
if "!" in name and not name.endswith("!") and receive:
raise TypeError("Specific channel names in receive() must end at the !")
return True

def require_valid_group_name(self, name):
if not self.match_type_and_length(name):
raise TypeError(self.invalid_name_error.format("Group"))
if not bool(self.group_name_regex.match(name)):
raise TypeError(self.invalid_name_error.format("Group"))
return True

def valid_channel_names(self, names, receive=False):
_non_empty_list = True if names else False
_names_type = isinstance(names, list)
assert _non_empty_list and _names_type, "names must be a non-empty list"

assert all(
self.valid_channel_name(channel, receive=receive) for channel in names
)
for channel in names:
self.require_valid_channel_name(channel, receive=receive)
return True

def non_local_name(self, name):
Expand Down Expand Up @@ -243,7 +239,7 @@ async def send(self, channel, message):
"""
# Typecheck
assert isinstance(message, dict), "message is not a dict"
assert self.valid_channel_name(channel), "Channel name not valid"
self.require_valid_channel_name(channel)
# If it's a process-local channel, strip off local part and stick full
# name in message
assert "__asgi_channel__" not in message
Expand All @@ -263,7 +259,7 @@ async def receive(self, channel):
If more than one coroutine waits on the same channel, a random one
of the waiting coroutines will get the result.
"""
assert self.valid_channel_name(channel)
self.require_valid_channel_name(channel)
self._clean_expired()

queue = self.channels.setdefault(
Expand Down Expand Up @@ -341,16 +337,16 @@ async def group_add(self, group, channel):
Adds the channel name to a group.
"""
# Check the inputs
assert self.valid_group_name(group), "Group name not valid"
assert self.valid_channel_name(channel), "Channel name not valid"
self.require_valid_group_name(group)
self.require_valid_channel_name(channel)
# Add to group dict
self.groups.setdefault(group, {})
self.groups[group][channel] = time.time()

async def group_discard(self, group, channel):
# Both should be text and valid
assert self.valid_channel_name(channel), "Invalid channel name"
assert self.valid_group_name(group), "Invalid group name"
self.require_valid_channel_name(channel)
self.require_valid_group_name(group)
# Remove from group set
group_channels = self.groups.get(group, None)
if group_channels:
Expand All @@ -363,7 +359,7 @@ async def group_discard(self, group, channel):
async def group_send(self, group, message):
# Check types
assert isinstance(message, dict), "Message is not a dict"
assert self.valid_group_name(group), "Invalid group name"
self.require_valid_group_name(group)
# Run clean
self._clean_expired()

Expand Down
41 changes: 40 additions & 1 deletion tests/test_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,10 @@ async def test_send_receive():

@pytest.mark.parametrize(
"method",
[BaseChannelLayer().valid_channel_name, BaseChannelLayer().valid_group_name],
[
BaseChannelLayer().require_valid_channel_name,
BaseChannelLayer().require_valid_group_name,
],
)
@pytest.mark.parametrize(
"channel_name,expected_valid",
Expand All @@ -84,3 +87,39 @@ def test_channel_and_group_name_validation(method, channel_name, expected_valid)
else:
with pytest.raises(TypeError):
method(channel_name)


@pytest.mark.parametrize(
"name",
[
"a" * 101, # Group name too long
],
)
def test_group_name_length_error_message(name):
"""
Ensure the correct error message is raised when group names
exceed the character limit or contain invalid characters.
"""
layer = BaseChannelLayer()
expected_error_message = layer.invalid_name_error.format("Group")

with pytest.raises(TypeError, match=expected_error_message):
layer.require_valid_group_name(name)


@pytest.mark.parametrize(
"name",
[
"a" * 101, # Channel name too long
],
)
def test_channel_name_length_error_message(name):
"""
Ensure the correct error message is raised when group names
exceed the character limit or contain invalid characters.
"""
layer = BaseChannelLayer()
expected_error_message = layer.invalid_name_error.format("Channel")

with pytest.raises(TypeError, match=expected_error_message):
layer.require_valid_channel_name(name)

0 comments on commit a144b4b

Please sign in to comment.