From a13f9eb96d75edced54e02da9cff14f5284ed882 Mon Sep 17 00:00:00 2001 From: Aaryan Jain Date: Tue, 21 Jan 2025 17:48:37 +0530 Subject: [PATCH] change the function name to require_valid_group_name and remove asserts whenever necessary --- channels/layers.py | 12 +++++------- tests/test_layers.py | 7 +++++-- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/channels/layers.py b/channels/layers.py index 09d6e53b..12656219 100644 --- a/channels/layers.py +++ b/channels/layers.py @@ -159,7 +159,7 @@ def valid_channel_name(self, name, receive=False): return True raise TypeError(self.invalid_name_error.format("Channel", name)) - def valid_group_name(self, name): + def require_valid_group_name(self, name): if len(name) >= self.MAX_NAME_LENGTH: raise TypeError( f"Group name must be less than {self.MAX_NAME_LENGTH} characters." @@ -345,8 +345,8 @@ async def group_add(self, group, channel): Adds the channel name to a group. """ # Check the inputs - assert self.valid_group_name(group) - assert self.valid_channel_name(channel), "Channel name not valid" + self.require_valid_group_name(group) + self.valid_channel_name(channel), "Channel name not valid" # Add to group dict self.groups.setdefault(group, {}) self.groups[group][channel] = time.time() @@ -354,7 +354,7 @@ async def group_add(self, group, channel): async def group_discard(self, group, channel): # Both should be text and valid assert self.valid_channel_name(channel), "Invalid channel name" - assert self.valid_group_name(group) + self.require_valid_group_name(group) # Remove from group set group_channels = self.groups.get(group, None) if group_channels: @@ -367,9 +367,7 @@ async def group_discard(self, group, channel): async def group_send(self, group, message): # Check types assert isinstance(message, dict), "Message is not a dict" - assert self.valid_group_name(group), ( - f"Group name must be" f"less than {self.MAX_NAME_LENGTH} characters." - ) + self.require_valid_group_name(group) # Run clean self._clean_expired() diff --git a/tests/test_layers.py b/tests/test_layers.py index 407e2839..21a0b463 100644 --- a/tests/test_layers.py +++ b/tests/test_layers.py @@ -72,7 +72,10 @@ async def test_send_receive(): @pytest.mark.parametrize( "method", - [BaseChannelLayer().valid_channel_name, BaseChannelLayer().valid_group_name], + [ + BaseChannelLayer().valid_channel_name, + BaseChannelLayer().require_valid_group_name, + ], ) @pytest.mark.parametrize( "channel_name,expected_valid", @@ -104,4 +107,4 @@ def test_group_name_length_error_message(name, expected_error_message): layer = BaseChannelLayer() with pytest.raises(TypeError, match=expected_error_message): - layer.valid_group_name(name) + layer.require_valid_group_name(name)