From 2e8c005b080385b688a0ab9a39d80d564b2f8344 Mon Sep 17 00:00:00 2001 From: Aaryan Jain Date: Fri, 10 Jan 2025 20:00:46 +0530 Subject: [PATCH 01/12] add warning of length of group --- channels/layers.py | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/channels/layers.py b/channels/layers.py index 99e7fbd6..91340abc 100644 --- a/channels/layers.py +++ b/channels/layers.py @@ -5,7 +5,8 @@ import string import time from copy import deepcopy - +import logging +logger = logging.getLogger(__name__) from django.conf import settings from django.core.signals import setting_changed from django.utils.module_loading import import_string @@ -160,7 +161,10 @@ def valid_channel_name(self, name, receive=False): raise TypeError(self.invalid_name_error.format("Channel", name)) def valid_group_name(self, name): - if self.match_type_and_length(name): + logger.debug(f"Validating group name: {name}, Length: {len(name)}") # Log group name length + if isinstance(name, str): + if len(name) >= self.MAX_NAME_LENGTH: + raise False if bool(self.group_name_regex.match(name)): return True raise TypeError(self.invalid_name_error.format("Group", name)) @@ -201,6 +205,7 @@ async def flush(self): raise NotImplementedError("flush() not implemented (flush extension)") async def group_add(self, group, channel): + print("Hello") raise NotImplementedError("group_add() not implemented (groups extension)") async def group_discard(self, group, channel): @@ -336,19 +341,28 @@ def _remove_from_groups(self, channel): # Groups extension + async def group_add(self, group, channel): """ Adds the channel name to a group. """ # Check the inputs - assert self.valid_group_name(group), "Group name not valid" + print(f"Validating group name: {group}") + assert self.valid_group_name(group), f"Group name must be less than {self.MAX_NAME_LENGTH} characters." assert self.valid_channel_name(channel), "Channel name not valid" - # Add to group dict + + # Check the length of the group name + if len(group) >= self.MAX_NAME_LENGTH: + raise TypeError(f"Group name must be less than {self.MAX_NAME_LENGTH} characters, but got {len(group)}.") + + # Add to group dict self.groups.setdefault(group, {}) self.groups[group][channel] = time.time() + async def group_discard(self, group, channel): # Both should be text and valid + print(f"Discarding channel {channel} from group {group}") # Log group name assert self.valid_channel_name(channel), "Invalid channel name" assert self.valid_group_name(group), "Invalid group name" # Remove from group set From a58f8d93b0faa5b34a1754a294765c1247da1504 Mon Sep 17 00:00:00 2001 From: IronJam <148959043+IronJam11@users.noreply.github.com> Date: Sat, 11 Jan 2025 20:20:32 +0530 Subject: [PATCH 02/12] fix minor issues to pass the tests --- channels/layers.py | 37 +++++++++++++++++-------------------- 1 file changed, 17 insertions(+), 20 deletions(-) diff --git a/channels/layers.py b/channels/layers.py index 91340abc..b90e86db 100644 --- a/channels/layers.py +++ b/channels/layers.py @@ -5,8 +5,7 @@ import string import time from copy import deepcopy -import logging -logger = logging.getLogger(__name__) + from django.conf import settings from django.core.signals import setting_changed from django.utils.module_loading import import_string @@ -161,10 +160,12 @@ def valid_channel_name(self, name, receive=False): raise TypeError(self.invalid_name_error.format("Channel", name)) def valid_group_name(self, name): - logger.debug(f"Validating group name: {name}, Length: {len(name)}") # Log group name length - if isinstance(name, str): - if len(name) >= self.MAX_NAME_LENGTH: - raise False + error_message = ( + f"Group name must be less than {self.MAX_NAME_LENGTH} characters." + ) + if len(name) >= self.MAX_NAME_LENGTH: + raise TypeError(error_message) + if self.match_type_and_length(name): if bool(self.group_name_regex.match(name)): return True raise TypeError(self.invalid_name_error.format("Group", name)) @@ -205,7 +206,6 @@ async def flush(self): raise NotImplementedError("flush() not implemented (flush extension)") async def group_add(self, group, channel): - print("Hello") raise NotImplementedError("group_add() not implemented (groups extension)") async def group_discard(self, group, channel): @@ -341,30 +341,25 @@ def _remove_from_groups(self, channel): # Groups extension - async def group_add(self, group, channel): """ Adds the channel name to a group. """ # Check the inputs - print(f"Validating group name: {group}") - assert self.valid_group_name(group), f"Group name must be less than {self.MAX_NAME_LENGTH} characters." + assert self.valid_group_name(group), ( + f"Group name must be" f"less than {self.MAX_NAME_LENGTH} characters." + ) assert self.valid_channel_name(channel), "Channel name not valid" - - # Check the length of the group name - if len(group) >= self.MAX_NAME_LENGTH: - raise TypeError(f"Group name must be less than {self.MAX_NAME_LENGTH} characters, but got {len(group)}.") - - # Add to group dict + # Add to group dict self.groups.setdefault(group, {}) self.groups[group][channel] = time.time() - async def group_discard(self, group, channel): # Both should be text and valid - print(f"Discarding channel {channel} from group {group}") # Log group name assert self.valid_channel_name(channel), "Invalid channel name" - assert self.valid_group_name(group), "Invalid group name" + assert self.valid_group_name(group), ( + f"Group name must be" f"less than {self.MAX_NAME_LENGTH} characters." + ) # Remove from group set group_channels = self.groups.get(group, None) if group_channels: @@ -377,7 +372,9 @@ async def group_discard(self, group, channel): async def group_send(self, group, message): # Check types assert isinstance(message, dict), "Message is not a dict" - assert self.valid_group_name(group), "Invalid group name" + assert self.valid_group_name(group), ( + f"Group name must be" f"less than {self.MAX_NAME_LENGTH} characters." + ) # Run clean self._clean_expired() From 19879b2f44253dbecbead83c28f2b490878dcf96 Mon Sep 17 00:00:00 2001 From: Aaryan Jain Date: Mon, 13 Jan 2025 20:27:01 +0530 Subject: [PATCH 03/12] add regression test --- tests/test_layers.py | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/tests/test_layers.py b/tests/test_layers.py index 543a9f19..b6ea57a6 100644 --- a/tests/test_layers.py +++ b/tests/test_layers.py @@ -69,18 +69,17 @@ async def test_send_receive(): await layer.send("test.channel", message) assert message == await layer.receive("test.channel") - -@pytest.mark.parametrize( - "method", - [BaseChannelLayer().valid_channel_name, BaseChannelLayer().valid_group_name], -) @pytest.mark.parametrize( - "channel_name,expected_valid", - [("¯\\_(ツ)_/¯", False), ("chat", True), ("chat" * 100, False)], + "name, expected_error_message", + [ + ("a" * 101, f"Group name must be less than {BaseChannelLayer.MAX_NAME_LENGTH} characters."), # Group name too long + ], ) -def test_channel_and_group_name_validation(method, channel_name, expected_valid): - if expected_valid: - method(channel_name) - else: - with pytest.raises(TypeError): - method(channel_name) +def test_group_name_length_error_message(name, expected_error_message): + """ + Ensure the correct error message is raised when group names exceed the character limit. + """ + layer = BaseChannelLayer() + + with pytest.raises(TypeError, match=expected_error_message): + layer.valid_group_name(name) \ No newline at end of file From 1bbd2d9c41871b3594904766fed81da040db06e8 Mon Sep 17 00:00:00 2001 From: Aaryan Jain Date: Mon, 13 Jan 2025 20:37:35 +0530 Subject: [PATCH 04/12] update the structure of the test code --- tests/test_layers.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/tests/test_layers.py b/tests/test_layers.py index b6ea57a6..9b248ea5 100644 --- a/tests/test_layers.py +++ b/tests/test_layers.py @@ -22,7 +22,6 @@ def test_config_error(self): If channel layer doesn't specify TEST_CONFIG, `make_test_backend` should result into error. """ - with self.assertRaises(InvalidChannelLayerError): channel_layers.make_test_backend(DEFAULT_CHANNEL_LAYER) @@ -39,7 +38,6 @@ def test_config_instance(self): If channel layer provides TEST_CONFIG, `make_test_backend` should return channel layer instance appropriate for testing. """ - layer = channel_layers.make_test_backend(DEFAULT_CHANNEL_LAYER) self.assertEqual(layer.expiry, 100500) @@ -64,22 +62,31 @@ def test_override_settings(self): @pytest.mark.asyncio async def test_send_receive(): + """ + Test that a message sent to a channel can be received correctly. + """ layer = InMemoryChannelLayer() message = {"type": "test.message"} await layer.send("test.channel", message) assert message == await layer.receive("test.channel") + @pytest.mark.parametrize( "name, expected_error_message", [ - ("a" * 101, f"Group name must be less than {BaseChannelLayer.MAX_NAME_LENGTH} characters."), # Group name too long + ( + "a" * 101, + f"Group name must be less than {BaseChannelLayer.MAX_NAME_LENGTH} " + "characters.", + ), # Group name too long ], ) def test_group_name_length_error_message(name, expected_error_message): """ - Ensure the correct error message is raised when group names exceed the character limit. + Ensure the correct error message is raised when group names + exceed the character limit. """ layer = BaseChannelLayer() with pytest.raises(TypeError, match=expected_error_message): - layer.valid_group_name(name) \ No newline at end of file + layer.valid_group_name(name) From cc46101886a8351de4bd4fac37409c80679bf8fa Mon Sep 17 00:00:00 2001 From: Aaryan Jain Date: Tue, 14 Jan 2025 19:26:26 +0530 Subject: [PATCH 05/12] add the erased test --- tests/test_layers.py | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/tests/test_layers.py b/tests/test_layers.py index 9b248ea5..407e2839 100644 --- a/tests/test_layers.py +++ b/tests/test_layers.py @@ -22,6 +22,7 @@ def test_config_error(self): If channel layer doesn't specify TEST_CONFIG, `make_test_backend` should result into error. """ + with self.assertRaises(InvalidChannelLayerError): channel_layers.make_test_backend(DEFAULT_CHANNEL_LAYER) @@ -38,6 +39,7 @@ def test_config_instance(self): If channel layer provides TEST_CONFIG, `make_test_backend` should return channel layer instance appropriate for testing. """ + layer = channel_layers.make_test_backend(DEFAULT_CHANNEL_LAYER) self.assertEqual(layer.expiry, 100500) @@ -62,15 +64,28 @@ def test_override_settings(self): @pytest.mark.asyncio async def test_send_receive(): - """ - Test that a message sent to a channel can be received correctly. - """ layer = InMemoryChannelLayer() message = {"type": "test.message"} await layer.send("test.channel", message) assert message == await layer.receive("test.channel") +@pytest.mark.parametrize( + "method", + [BaseChannelLayer().valid_channel_name, BaseChannelLayer().valid_group_name], +) +@pytest.mark.parametrize( + "channel_name,expected_valid", + [("¯\\_(ツ)_/¯", False), ("chat", True), ("chat" * 100, False)], +) +def test_channel_and_group_name_validation(method, channel_name, expected_valid): + if expected_valid: + method(channel_name) + else: + with pytest.raises(TypeError): + method(channel_name) + + @pytest.mark.parametrize( "name, expected_error_message", [ From 50f53bc04b97a3d811f36b670c712dcce31b11a4 Mon Sep 17 00:00:00 2001 From: Aaryan Jain Date: Thu, 16 Jan 2025 01:08:24 +0530 Subject: [PATCH 06/12] remove unnecessary assertions --- channels/layers.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/channels/layers.py b/channels/layers.py index b90e86db..bb111b0d 100644 --- a/channels/layers.py +++ b/channels/layers.py @@ -346,9 +346,7 @@ async def group_add(self, group, channel): Adds the channel name to a group. """ # Check the inputs - assert self.valid_group_name(group), ( - f"Group name must be" f"less than {self.MAX_NAME_LENGTH} characters." - ) + assert self.valid_group_name(group) assert self.valid_channel_name(channel), "Channel name not valid" # Add to group dict self.groups.setdefault(group, {}) @@ -357,9 +355,7 @@ async def group_add(self, group, channel): async def group_discard(self, group, channel): # Both should be text and valid assert self.valid_channel_name(channel), "Invalid channel name" - assert self.valid_group_name(group), ( - f"Group name must be" f"less than {self.MAX_NAME_LENGTH} characters." - ) + assert self.valid_group_name(group) # Remove from group set group_channels = self.groups.get(group, None) if group_channels: From 48533bdcc45547fd27aae457554aed93ec1cc09c Mon Sep 17 00:00:00 2001 From: Aaryan Jain Date: Tue, 21 Jan 2025 17:23:30 +0530 Subject: [PATCH 07/12] inline the error_message --- channels/layers.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/channels/layers.py b/channels/layers.py index bb111b0d..09d6e53b 100644 --- a/channels/layers.py +++ b/channels/layers.py @@ -160,11 +160,10 @@ def valid_channel_name(self, name, receive=False): raise TypeError(self.invalid_name_error.format("Channel", name)) def valid_group_name(self, name): - error_message = ( - f"Group name must be less than {self.MAX_NAME_LENGTH} characters." - ) if len(name) >= self.MAX_NAME_LENGTH: - raise TypeError(error_message) + raise TypeError( + f"Group name must be less than {self.MAX_NAME_LENGTH} characters." + ) if self.match_type_and_length(name): if bool(self.group_name_regex.match(name)): return True From a13f9eb96d75edced54e02da9cff14f5284ed882 Mon Sep 17 00:00:00 2001 From: Aaryan Jain Date: Tue, 21 Jan 2025 17:48:37 +0530 Subject: [PATCH 08/12] change the function name to require_valid_group_name and remove asserts whenever necessary --- channels/layers.py | 12 +++++------- tests/test_layers.py | 7 +++++-- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/channels/layers.py b/channels/layers.py index 09d6e53b..12656219 100644 --- a/channels/layers.py +++ b/channels/layers.py @@ -159,7 +159,7 @@ def valid_channel_name(self, name, receive=False): return True raise TypeError(self.invalid_name_error.format("Channel", name)) - def valid_group_name(self, name): + def require_valid_group_name(self, name): if len(name) >= self.MAX_NAME_LENGTH: raise TypeError( f"Group name must be less than {self.MAX_NAME_LENGTH} characters." @@ -345,8 +345,8 @@ async def group_add(self, group, channel): Adds the channel name to a group. """ # Check the inputs - assert self.valid_group_name(group) - assert self.valid_channel_name(channel), "Channel name not valid" + self.require_valid_group_name(group) + self.valid_channel_name(channel), "Channel name not valid" # Add to group dict self.groups.setdefault(group, {}) self.groups[group][channel] = time.time() @@ -354,7 +354,7 @@ async def group_add(self, group, channel): async def group_discard(self, group, channel): # Both should be text and valid assert self.valid_channel_name(channel), "Invalid channel name" - assert self.valid_group_name(group) + self.require_valid_group_name(group) # Remove from group set group_channels = self.groups.get(group, None) if group_channels: @@ -367,9 +367,7 @@ async def group_discard(self, group, channel): async def group_send(self, group, message): # Check types assert isinstance(message, dict), "Message is not a dict" - assert self.valid_group_name(group), ( - f"Group name must be" f"less than {self.MAX_NAME_LENGTH} characters." - ) + self.require_valid_group_name(group) # Run clean self._clean_expired() diff --git a/tests/test_layers.py b/tests/test_layers.py index 407e2839..21a0b463 100644 --- a/tests/test_layers.py +++ b/tests/test_layers.py @@ -72,7 +72,10 @@ async def test_send_receive(): @pytest.mark.parametrize( "method", - [BaseChannelLayer().valid_channel_name, BaseChannelLayer().valid_group_name], + [ + BaseChannelLayer().valid_channel_name, + BaseChannelLayer().require_valid_group_name, + ], ) @pytest.mark.parametrize( "channel_name,expected_valid", @@ -104,4 +107,4 @@ def test_group_name_length_error_message(name, expected_error_message): layer = BaseChannelLayer() with pytest.raises(TypeError, match=expected_error_message): - layer.valid_group_name(name) + layer.require_valid_group_name(name) From 792c79e4d807f39ffaff4bb7799dde46c6451905 Mon Sep 17 00:00:00 2001 From: Aaryan Jain Date: Wed, 22 Jan 2025 13:02:02 +0530 Subject: [PATCH 09/12] applying the same changes for channel --- channels/layers.py | 48 +++++++++++++++++++------------------------- tests/test_layers.py | 33 +++++++++++++++++++++--------- 2 files changed, 45 insertions(+), 36 deletions(-) diff --git a/channels/layers.py b/channels/layers.py index 12656219..801abd89 100644 --- a/channels/layers.py +++ b/channels/layers.py @@ -144,38 +144,32 @@ def match_type_and_length(self, name): invalid_name_error = ( "{} name must be a valid unicode string " + "with length < {} ".format(MAX_NAME_LENGTH) - + "containing only ASCII alphanumerics, hyphens, underscores, or periods, " - + "not {}" + + "containing only ASCII alphanumerics, hyphens, underscores, or periods." ) - def valid_channel_name(self, name, receive=False): - if self.match_type_and_length(name): - if bool(self.channel_name_regex.match(name)): - # Check cases for special channels - if "!" in name and not name.endswith("!") and receive: - raise TypeError( - "Specific channel names in receive() must end at the !" - ) - return True - raise TypeError(self.invalid_name_error.format("Channel", name)) + def require_valid_channel_name(self, name, receive=False): + if not self.match_type_and_length(name): + raise TypeError(self.invalid_name_error.format("Channel")) + if not bool(self.channel_name_regex.match(name)): + raise TypeError(self.invalid_name_error.format("Channel")) + if "!" in name and not name.endswith("!") and receive: + raise TypeError("Specific channel names in receive() must end at the !") + return True def require_valid_group_name(self, name): - if len(name) >= self.MAX_NAME_LENGTH: - raise TypeError( - f"Group name must be less than {self.MAX_NAME_LENGTH} characters." - ) - if self.match_type_and_length(name): - if bool(self.group_name_regex.match(name)): - return True - raise TypeError(self.invalid_name_error.format("Group", name)) + if not self.match_type_and_length(name): + raise TypeError(self.invalid_name_error.format("Group")) + if not bool(self.group_name_regex.match(name)): + raise TypeError(self.invalid_name_error.format("Group")) + return True def valid_channel_names(self, names, receive=False): _non_empty_list = True if names else False _names_type = isinstance(names, list) assert _non_empty_list and _names_type, "names must be a non-empty list" - - assert all( - self.valid_channel_name(channel, receive=receive) for channel in names + all( + self.require_valid_channel_name(channel, receive=receive) + for channel in names ) return True @@ -247,7 +241,7 @@ async def send(self, channel, message): """ # Typecheck assert isinstance(message, dict), "message is not a dict" - assert self.valid_channel_name(channel), "Channel name not valid" + self.require_valid_channel_name(channel) # If it's a process-local channel, strip off local part and stick full # name in message assert "__asgi_channel__" not in message @@ -267,7 +261,7 @@ async def receive(self, channel): If more than one coroutine waits on the same channel, a random one of the waiting coroutines will get the result. """ - assert self.valid_channel_name(channel) + self.require_valid_channel_name(channel) self._clean_expired() queue = self.channels.setdefault( @@ -346,14 +340,14 @@ async def group_add(self, group, channel): """ # Check the inputs self.require_valid_group_name(group) - self.valid_channel_name(channel), "Channel name not valid" + self.require_valid_channel_name(channel) # Add to group dict self.groups.setdefault(group, {}) self.groups[group][channel] = time.time() async def group_discard(self, group, channel): # Both should be text and valid - assert self.valid_channel_name(channel), "Invalid channel name" + self.require_valid_channel_name(channel) self.require_valid_group_name(group) # Remove from group set group_channels = self.groups.get(group, None) diff --git a/tests/test_layers.py b/tests/test_layers.py index 21a0b463..7b02c155 100644 --- a/tests/test_layers.py +++ b/tests/test_layers.py @@ -73,7 +73,7 @@ async def test_send_receive(): @pytest.mark.parametrize( "method", [ - BaseChannelLayer().valid_channel_name, + BaseChannelLayer().require_valid_channel_name, BaseChannelLayer().require_valid_group_name, ], ) @@ -90,21 +90,36 @@ def test_channel_and_group_name_validation(method, channel_name, expected_valid) @pytest.mark.parametrize( - "name, expected_error_message", + "name", [ - ( - "a" * 101, - f"Group name must be less than {BaseChannelLayer.MAX_NAME_LENGTH} " - "characters.", - ), # Group name too long + "a" * 101, # Group name too long ], ) -def test_group_name_length_error_message(name, expected_error_message): +def test_group_name_length_error_message(name): """ Ensure the correct error message is raised when group names - exceed the character limit. + exceed the character limit or contain invalid characters. """ layer = BaseChannelLayer() + expected_error_message = layer.invalid_name_error.format("Group") with pytest.raises(TypeError, match=expected_error_message): layer.require_valid_group_name(name) + + +@pytest.mark.parametrize( + "name", + [ + "a" * 101, # Channel name too long + ], +) +def test_channel_name_length_error_message(name): + """ + Ensure the correct error message is raised when group names + exceed the character limit or contain invalid characters. + """ + layer = BaseChannelLayer() + expected_error_message = layer.invalid_name_error.format("Channel") + + with pytest.raises(TypeError, match=expected_error_message): + layer.require_valid_channel_name(name) From ae711a352b10869581b1cab33df0c53ca0ce45ef Mon Sep 17 00:00:00 2001 From: Aaryan Jain Date: Sun, 26 Jan 2025 14:14:37 +0530 Subject: [PATCH 10/12] refactor all to use for loop --- channels/layers.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/channels/layers.py b/channels/layers.py index 801abd89..d5632032 100644 --- a/channels/layers.py +++ b/channels/layers.py @@ -167,12 +167,9 @@ def valid_channel_names(self, names, receive=False): _non_empty_list = True if names else False _names_type = isinstance(names, list) assert _non_empty_list and _names_type, "names must be a non-empty list" - all( + for channel in names: self.require_valid_channel_name(channel, receive=receive) - for channel in names - ) return True - def non_local_name(self, name): """ Given a channel name, returns the "non-local" part. If the channel name From 027aad1d3689bcc05995c8d392aebf4117dc5ddd Mon Sep 17 00:00:00 2001 From: Aaryan Jain Date: Mon, 27 Jan 2025 11:09:31 +0530 Subject: [PATCH 11/12] remove lint errors --- channels/layers.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/channels/layers.py b/channels/layers.py index d5632032..8aaaddec 100644 --- a/channels/layers.py +++ b/channels/layers.py @@ -169,7 +169,8 @@ def valid_channel_names(self, names, receive=False): assert _non_empty_list and _names_type, "names must be a non-empty list" for channel in names: self.require_valid_channel_name(channel, receive=receive) - return True + return + def non_local_name(self, name): """ Given a channel name, returns the "non-local" part. If the channel name From 8b6c547e1c0fe84a1503aebf507bdc490d7f7222 Mon Sep 17 00:00:00 2001 From: Aaryan Jain Date: Mon, 27 Jan 2025 17:04:55 +0530 Subject: [PATCH 12/12] correct the return instance --- channels/layers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/channels/layers.py b/channels/layers.py index 8aaaddec..5fc53f74 100644 --- a/channels/layers.py +++ b/channels/layers.py @@ -169,7 +169,7 @@ def valid_channel_names(self, names, receive=False): assert _non_empty_list and _names_type, "names must be a non-empty list" for channel in names: self.require_valid_channel_name(channel, receive=receive) - return + return True def non_local_name(self, name): """