diff --git a/CHANGELOG.md b/CHANGELOG.md index 749ffaf53..721907fb0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ * Collaborative editing in project history diff views * Project history diff views: add revert changes button to markdown editor * Fix MDE preview layout break on zoom out +* Send update_text events with text diff when updating text fields via API instead of overwriting the whole text ## v2024.040 - 2024-05-15 diff --git a/api/src/reportcreator_api/conf/settings.py b/api/src/reportcreator_api/conf/settings.py index bc894913b..25bf85a47 100644 --- a/api/src/reportcreator_api/conf/settings.py +++ b/api/src/reportcreator_api/conf/settings.py @@ -666,10 +666,12 @@ def __bool__(self): 'formatter': 'default', 'class': 'logging.StreamHandler', }, - 'elasticapm': { - 'level': 'WARNING', - 'class': 'elasticapm.contrib.django.handlers.LoggingHandler', - }, + **({ + 'elasticapm': { + 'level': 'WARNING', + 'class': 'elasticapm.contrib.django.handlers.LoggingHandler', + }, + } if ELASTIC_APM_ENABLED else {}), }, 'root': { 'level': 'INFO', diff --git a/api/src/reportcreator_api/conf/urls.py b/api/src/reportcreator_api/conf/urls.py index 3b0290799..3e1cc3ff4 100644 --- a/api/src/reportcreator_api/conf/urls.py +++ b/api/src/reportcreator_api/conf/urls.py @@ -10,6 +10,7 @@ from reportcreator_api.api_utils.views import UtilsViewSet from reportcreator_api.notifications.views import NotificationViewSet +from reportcreator_api.pentests.collab.channels import ConsumerHttpFallbackView from reportcreator_api.pentests.consumers import ProjectNotesConsumer, ProjectReportingConsumer, UserNotesConsumer from reportcreator_api.pentests.views import ( ArchivedProjectKeyPartViewSet, @@ -40,7 +41,6 @@ MFAMethodViewSet, PentestUserViewSet, ) -from reportcreator_api.utils.channels import ConsumerHttpFallbackView router = DefaultRouter() # Make trailing slash in URL optional to support loading images and assets by fielname diff --git a/api/src/reportcreator_api/pentests/collab/__init__.py b/api/src/reportcreator_api/pentests/collab/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/api/src/reportcreator_api/utils/channels.py b/api/src/reportcreator_api/pentests/collab/channels.py similarity index 100% rename from api/src/reportcreator_api/utils/channels.py rename to api/src/reportcreator_api/pentests/collab/channels.py diff --git a/api/src/reportcreator_api/pentests/collab/consumer_base.py b/api/src/reportcreator_api/pentests/collab/consumer_base.py new file mode 100644 index 000000000..4476c6e46 --- /dev/null +++ b/api/src/reportcreator_api/pentests/collab/consumer_base.py @@ -0,0 +1,387 @@ +import itertools +import json +import logging +from datetime import timedelta +from functools import cached_property +from typing import Any + +from channels.db import database_sync_to_async +from channels.exceptions import DenyConnection, StopConsumer +from channels.generic.websocket import AsyncJsonWebsocketConsumer +from django.core.exceptions import ValidationError +from django.core.serializers.json import DjangoJSONEncoder +from django.db import transaction +from django.utils import timezone +from django.utils.crypto import get_random_string +from randomcolor import RandomColor + +from reportcreator_api.pentests.collab.text_transformations import EditorSelection, Update, rebase_updates +from reportcreator_api.pentests.models import ( + CollabClientInfo, + CollabEvent, + CollabEventType, +) +from reportcreator_api.pentests.models.collab import collab_context +from reportcreator_api.users.serializers import PentestUserSerializer +from reportcreator_api.utils.elasticapm import elasticapm_capture_websocket_transaction +from reportcreator_api.utils.history import history_context +from reportcreator_api.utils.utils import aretry + +log = logging.getLogger(__name__) + + +class WebsocketConsumerBase(AsyncJsonWebsocketConsumer): + last_permission_check_time = None + initial_path = None + + @property + def related_id(self): + raise NotImplementedError() + + async def dispatch(self, message): + try: + if not message.get('type', '').startswith('websocket.') and not await self.check_permission(action='read', skip_on_recent_check=True): + await self.close(code=4443) + return + + with history_context(history_user=self.scope.get('user')): + await super().dispatch(message) + except StopConsumer: + await self.delete_client_info() + raise + except Exception as ex: + await self.delete_client_info() + log.exception(ex) + raise ex + + async def websocket_connect(self, message): + async with elasticapm_capture_websocket_transaction(scope=self.scope, event={'type': 'websocket.connect'}): + # Log connection + user = '' + if self.scope.get('user') and not self.scope['user'].is_anonymous: + user = self.scope['user'].username + logging.info(f'CONNECT {self.scope['path']} (user={user})') + + # Set user.admin_permissions_enabled + if self.scope.get('user') and self.scope.get('session', {}).get('admin_permissions_enabled'): + self.scope['user'].admin_permissions_enabled = True + + with history_context(history_user=self.scope.get('user')): + return await super().websocket_connect(message) + + async def websocket_receive(self, message): + event = await self.decode_json(message.get('text', '{}')) + if event.get('type') == 'ping': + await self.send_json({'type': 'ping'}) + return + + async with elasticapm_capture_websocket_transaction(scope=self.scope, event=event): + if not await self.check_permission(action='write', event=event): + await self.close(code=4443) + return + + try: + with history_context(history_user=self.scope.get('user')): + return await super().websocket_receive(message) + except ValidationError as ex: + await self.send_json({ + 'type': 'error', + 'message': ex.message, + }) + + async def websocket_disconnect(self, message): + try: + return await super().websocket_disconnect(message) + finally: + user = '' + if self.scope.get('user') and not self.scope['user'].is_anonymous: + user = self.scope['user'].username + logging.info(f'DISCONNECT {self.scope['path']} (user={user})') + + async def encode_json(self, content): + return json.dumps(content, cls=DjangoJSONEncoder) + + @property + def group_name(self) -> str: + raise NotImplementedError() + + @cached_property + def client_id(self) -> str: + return self.scope.get('client_id') or f'{self.scope['user'].id}/{get_random_string(8)}' + + @cached_property + def client_color(self) -> str: + return RandomColor(seed=get_random_string(8)).generate(luminosity='bright')[0] + + @database_sync_to_async + def check_permission(self, skip_on_recent_check=False, action=None, **kwargs): + # Skip permission check if it was done recently + if skip_on_recent_check and self.last_permission_check_time and self.last_permission_check_time + timedelta(seconds=60) >= timezone.now(): + return True + + # Check if session is still valid + session = self.scope.get('session') + if not session or not session.session_key or \ + session.expire_date < timezone.now() or \ + not session.exists(session.session_key): + return False + + # Check custom permissions + res = self.has_permission(action=action, **kwargs) + self.last_permission_check_time = timezone.now() + return res + + def has_permission(self, **kwargs): + return True + + @database_sync_to_async + def create_client_info(self): + CollabClientInfo.objects.create( + related_id=self.related_id, + user=self.scope['user'], + client_id=self.client_id, + client_color=self.client_color, + path=self.initial_path, + ) + + @database_sync_to_async + def delete_client_info(self): + CollabClientInfo.objects \ + .filter(client_id=self.client_id) \ + .delete() + + def filter_path(self, qs_or_obj): + return qs_or_obj + + def get_client_infos(self): + clients = CollabClientInfo.objects \ + .filter(related_id=self.related_id) \ + .select_related('user') + clients = self.filter_path(clients) + + return [{ + 'client_id': c.client_id, + 'client_color': c.client_color, + 'user': PentestUserSerializer(c.user).data, + 'path': c.path, + } for c in clients] + + async def get_initial_message(self): + return None + + async def get_connect_message(self): + return { + 'type': CollabEventType.CONNECT, + 'client_id': self.client_id, + 'path': self.initial_path, + 'client': { + 'client_id': self.client_id, + 'client_color': self.client_color, + 'user': PentestUserSerializer(self.scope['user']).data, + }, + } + + async def get_disconnect_message(self): + return { + 'type': CollabEventType.DISCONNECT, + 'client_id': self.client_id, + 'path': self.initial_path, + } + + async def connect(self): + if not await self.check_permission(action='connect'): + raise DenyConnection() + + await super().connect() + await self.create_client_info() + if initial_msg := await self.get_initial_message(): + await self.send_json(initial_msg) + + await self.channel_layer.group_add(self.group_name, self.channel_name) + if connect_msg := await self.get_connect_message(): + await self.send_colllab_event(connect_msg) + + async def disconnect(self, close_code): + await self.channel_layer.group_discard(self.group_name, self.channel_name) + await self.delete_client_info() + if disconnect_msg := await self.get_disconnect_message(): + await self.send_colllab_event(disconnect_msg) + await super().disconnect(close_code) + + async def send_colllab_event(self, event): + if not event: + return + elif isinstance(event, CollabEvent): + await self.channel_layer.group_send(self.group_name, { + 'type': 'collab_event', + 'id': str(event.id), + 'path': event.path, + }) + else: + await self.channel_layer.group_send(self.group_name, { + 'type': 'collab_event', + 'path': event.get('path'), + 'event': event, + }) + + async def collab_event(self, event): + if not self.filter_path(event): + return + + if event.get('id'): + @database_sync_to_async + def get_collab_event(id): + return CollabEvent.objects.get(id=id) + + # Retry fetching event from DB: DB transactions can cause the channels event to arrive before event data is commited to the DB + collab_event = await aretry(lambda: get_collab_event(event['id']), retry_for=CollabEvent.DoesNotExist) + await self.send_json({ + 'type': collab_event.type, + 'path': collab_event.path, + 'client_id': collab_event.client_id, + 'version': collab_event.version, + **collab_event.data, + }) + elif isinstance(event.get('event'), dict): + await self.send_json(event['event']) + + +class CollabUpdateTextMixin: + def get_object_for_update(self, content): + raise NotImplementedError() + + def perform_update_text(self, obj, path, definition, changes): + raise NotImplementedError() + + @database_sync_to_async + @transaction.atomic() + def collab_update_text(self, content): + obj, path, definition = self.get_object_for_update(content) + + version = content['version'] + # TODO: reject updates for versions that are too old + # * check if version is too old and if there are updates in between + # * simple timestamp comparison is not enough, because when there were no updates in between, the version is still valid + # * checking version < note.version is not enough, because of concurrent updates (e.g. old version, update1 succeeds, update2 fails because of updated version) + + + # Rebase updates + over_updates = CollabEvent.objects \ + .filter(related_id=self.related_id) \ + .filter(path=content['path']) \ + .filter(type=CollabEventType.UPDATE_TEXT) \ + .filter(version__gt=version) \ + .order_by('version') + updates, selection = rebase_updates( + updates=[Update.from_dict(u | {'client_id': self.client_id, 'version': version}) for u in content.get('updates', [])], + selection=EditorSelection.from_dict(content['selection']) if content.get('selection') else None, + over=list(itertools.chain(*[[ + Update.from_dict(u | {'client_id': e.client_id, 'version': version}) + for u in e.data.get('updates', [])] for e in over_updates])), + ) + if not updates: + raise ValidationError('No updates') + + # Update in DB + changes = updates[0].changes + for u in updates[1:]: + changes = changes.compose(u.changes) + + with collab_context(prevent_events=True): + obj = self.perform_update_text(obj, path, definition, changes) + + # Update client info + CollabClientInfo.objects \ + .filter(client_id=self.client_id) \ + .update(path=content['path']) + + # Store OT event in DB + return CollabEvent.objects.create( + related_id=self.related_id, + path=content['path'], + type=CollabEventType.UPDATE_TEXT, + created=obj.updated, + version=obj.updated.timestamp(), + client_id=self.client_id, + data={ + 'updates': [u.to_dict() for u in updates], + **({'selection': selection.to_dict()} if selection else {}), + }, + ) + + +class CollabUpdateKeyMixin: + def get_object_for_update(self, content) -> tuple[Any, list[str]]: + raise NotImplementedError() + + def perform_update_key(self, obj, path, definition, value): + raise NotImplementedError() + + @database_sync_to_async + @transaction.atomic() + def collab_update_key(self, content): + obj, path, definition = self.get_object_for_update(content) + + # Update in DB + with collab_context(prevent_events=True): + obj = self.perform_update_key(obj, path, definition, content['value']) + + # Update client info + if content.get('update_awareness', False): + CollabClientInfo.objects \ + .filter(client_id=self.client_id) \ + .update(path=content['path']) + + # Store OT event in DB + return CollabEvent.objects.create( + related_id=self.related_id, + path=content['path'], + type=CollabEventType.UPDATE_KEY, + created=obj.updated, + version=obj.updated.timestamp(), + client_id=self.client_id, + data={ + 'value': content['value'], + }, + ) + + +class CollabUpdateAwarenessMixin: + @database_sync_to_async + def collab_update_awareness(self, content): + path = content.get('path') + + version = content['version'] + + selection = None + if content.get('path') and content.get('selection'): + over_events = CollabEvent.objects \ + .filter(related_id=self.related_id) \ + .filter(path=path) \ + .filter(type=CollabEventType.UPDATE_TEXT) \ + .filter(version__gt=version) \ + .order_by('version') + over_updates = list(itertools.chain(*[[ + Update.from_dict(u | {'client_id': self.client_id, 'version': version}) + for u in e.data.get('updates', [])] for e in over_events])) + version = max([e.version for e in over_updates] + [version]) + + selection = EditorSelection.from_dict(content['selection']) + for u in over_updates: + selection = selection.map(u.changes) + + # Update client info + CollabClientInfo.objects \ + .filter(client_id=self.client_id) \ + .update(path=path) + + return { + 'type': CollabEventType.AWARENESS, + 'path': path, + 'client_id': self.client_id, + **({'selection': selection.to_dict()} if selection else {}), + } + + +class GenericCollabMixin(CollabUpdateKeyMixin, CollabUpdateTextMixin, CollabUpdateAwarenessMixin): + pass diff --git a/api/src/reportcreator_api/utils/text_transformations.py b/api/src/reportcreator_api/pentests/collab/text_transformations.py similarity index 87% rename from api/src/reportcreator_api/utils/text_transformations.py rename to api/src/reportcreator_api/pentests/collab/text_transformations.py index 55e477523..02cec503c 100644 --- a/api/src/reportcreator_api/utils/text_transformations.py +++ b/api/src/reportcreator_api/pentests/collab/text_transformations.py @@ -25,6 +25,8 @@ """ import dataclasses +import itertools +from difflib import SequenceMatcher from typing import Optional from reportcreator_api.utils.utils import get_at @@ -38,8 +40,10 @@ class CollabStr: See https://hsivonen.fi/string-length/ """ - def __init__(self, py_str: str|bytes) -> None: - if isinstance(py_str, bytes): + def __init__(self, py_str) -> None: + if isinstance(py_str, CollabStr): + self.str_bytes = py_str.str_bytes + elif isinstance(py_str, bytes): self.str_bytes = py_str else: self.str_bytes = py_str \ @@ -74,6 +78,23 @@ def __add__(self, other): else: raise TypeError('Invalid argument type') + def __eq__(self, value: object) -> bool: + if isinstance(value, CollabStr): + return self.str_bytes == value.str_bytes + elif isinstance(value, str): + return self.str_bytes == CollabStr(value).str_bytes + else: + return self.str_bytes == value + + def __hash__(self) -> int: + return hash(self.str_bytes) + + def __iter__(self): + return map(lambda b: CollabStr(bytes(b)), itertools.batched(self.str_bytes, 2)) + + def join(self, iterable): + return CollabStr(self.str_bytes.join(s.str_bytes for s in iterable)) + @dataclasses.dataclass class ChangeSet: @@ -104,12 +125,12 @@ def from_dict(cls, changes: list): for i, part in enumerate(changes): if isinstance(part, int): sections.extend([part, -1]) - elif not isinstance(part, list) or len(part) == 0 or not isinstance(part[0], int) or not all(map(lambda e: isinstance(e, str), part[1:])): + elif not isinstance(part, list) or len(part) == 0 or not isinstance(part[0], int) or not all(map(lambda e: isinstance(e, (str, CollabStr)), part[1:])): raise ValueError('Invalid change') else: while len(inserted) <= i: inserted.append(CollabStr('')) # Text.empty - inserted[i] = CollabStr('\n'.join(part[1:])) + inserted[i] = CollabStr('\n').join(map(CollabStr, part[1:])) sections.extend([part[0], len(inserted[i])]) return ChangeSet(sections=sections, inserted=inserted) @@ -225,6 +246,10 @@ def apply(self, doc: str): doc = doc[:from_b] + text + doc[from_b + (to_a - from_a):] return str(doc) + @classmethod + def from_diff(cls, text_before, text_after): + return ChangeSet.from_dict(list(diff_lines(text_before.replace('\r\n', '\n'), text_after.replace('\r\n', '\n')))) + @dataclasses.dataclass class SelectionRange: @@ -504,6 +529,43 @@ def compose_sets(setA: ChangeSet, setB: ChangeSet): b.forward(i_len) +def diff_lines(text_before: str, text_after: str): + lines_before = text_before.splitlines(keepends=True) + lines_after = text_after.splitlines(keepends=True) + + idx_before = 0 + for tag, alo, ahi, blo, bhi in SequenceMatcher(a=lines_before, b=lines_after).get_opcodes(): + # Use CollabStr to calculate indices and lengths to handle unicode characters correctly + a_str = CollabStr(''.join(lines_before[alo:ahi])) + b_str = CollabStr(''.join(lines_after[blo:bhi])) + idx_after = idx_before + len(a_str) + + match tag: + case 'equal': + yield idx_after - idx_before + case 'insert': + yield [0, b_str] + case 'delete': + yield [idx_after - idx_before, ''] + case 'replace': + yield from diff_characters(str(a_str), str(b_str)) + idx_before = idx_after + + +def diff_characters(text_before: str, text_after: str): + idx_before = 0 + # Calculate diff using python strings to not split unicode characters + for tag, alo, ahi, blo, bhi in SequenceMatcher(a=text_before, b=text_after).get_opcodes(): + # Use CollabStr to calculate indices and lengths to handle unicode characters correctly + a_str = CollabStr(text_before[alo:ahi]) + b_str = CollabStr(text_after[blo:bhi]) + idx_after = idx_before + len(a_str) + if tag == 'equal': + yield idx_after - idx_before + else: + yield [idx_after - idx_before, b_str] + + def rebase_updates(updates: list[Update], selection: Optional[EditorSelection], over: list[Update]) -> tuple[list[Update], Optional[EditorSelection]]: """ Rebase and deduplicate an array of client-submitted updates that diff --git a/api/src/reportcreator_api/pentests/consumers.py b/api/src/reportcreator_api/pentests/consumers.py index 6c79ca424..196d33ad4 100644 --- a/api/src/reportcreator_api/pentests/consumers.py +++ b/api/src/reportcreator_api/pentests/consumers.py @@ -1,26 +1,16 @@ -import itertools -import json import logging -from datetime import timedelta -from functools import cached_property from asgiref.sync import async_to_sync from channels.db import database_sync_to_async -from channels.exceptions import DenyConnection, StopConsumer -from channels.generic.websocket import AsyncJsonWebsocketConsumer from channels.layers import get_channel_layer from django.core.exceptions import ValidationError -from django.core.serializers.json import DjangoJSONEncoder from django.db import models, transaction from django.db.models import Prefetch -from django.utils import timezone -from django.utils.crypto import get_random_string -from randomcolor import RandomColor +from reportcreator_api.pentests.collab.consumer_base import GenericCollabMixin, WebsocketConsumerBase from reportcreator_api.pentests.customfields.types import FieldDataType from reportcreator_api.pentests.customfields.utils import get_value_at_path, iterate_fields, set_value_at_path from reportcreator_api.pentests.models import ( - CollabClientInfo, CollabEvent, CollabEventType, PentestFinding, @@ -32,228 +22,12 @@ ) from reportcreator_api.pentests.serializers.notes import ProjectNotebookPageSerializer, UserNotebookPageSerializer from reportcreator_api.pentests.serializers.project import PentestFindingSerializer, ReportSectionSerializer -from reportcreator_api.users.serializers import PentestUserSerializer -from reportcreator_api.utils.elasticapm import elasticapm_capture_websocket_transaction -from reportcreator_api.utils.history import history_context -from reportcreator_api.utils.text_transformations import EditorSelection, Update, rebase_updates -from reportcreator_api.utils.utils import aretry, is_uuid +from reportcreator_api.utils.utils import is_uuid log = logging.getLogger(__name__) -class WebsocketConsumerBase(AsyncJsonWebsocketConsumer): - last_permission_check_time = None - initial_path = None - - async def dispatch(self, message): - try: - if not message.get('type', '').startswith('websocket.') and not await self.check_permission(action='read', skip_on_recent_check=True): - await self.close(code=4443) - return - - with history_context(history_user=self.scope.get('user')): - await super().dispatch(message) - except StopConsumer: - await self.delete_client_info() - raise - except Exception as ex: - await self.delete_client_info() - log.exception(ex) - raise ex - - async def websocket_connect(self, message): - async with elasticapm_capture_websocket_transaction(scope=self.scope, event={'type': 'websocket.connect'}): - # Log connection - user = '' - if self.scope.get('user') and not self.scope['user'].is_anonymous: - user = self.scope['user'].username - logging.info(f'CONNECT {self.scope['path']} (user={user})') - - # Set user.admin_permissions_enabled - if self.scope.get('user') and self.scope.get('session', {}).get('admin_permissions_enabled'): - self.scope['user'].admin_permissions_enabled = True - - with history_context(history_user=self.scope.get('user')): - return await super().websocket_connect(message) - - async def websocket_receive(self, message): - event = await self.decode_json(message.get('text', '{}')) - if event.get('type') == 'ping': - await self.send_json({'type': 'ping'}) - return - - async with elasticapm_capture_websocket_transaction(scope=self.scope, event=event): - if not await self.check_permission(action='write', event=event): - await self.close(code=4443) - return - - try: - with history_context(history_user=self.scope.get('user')): - return await super().websocket_receive(message) - except ValidationError as ex: - await self.send_json({ - 'type': 'error', - 'message': ex.message, - }) - - async def websocket_disconnect(self, message): - try: - return await super().websocket_disconnect(message) - finally: - user = '' - if self.scope.get('user') and not self.scope['user'].is_anonymous: - user = self.scope['user'].username - logging.info(f'DISCONNECT {self.scope['path']} (user={user})') - - async def encode_json(self, content): - return json.dumps(content, cls=DjangoJSONEncoder) - - @property - def group_name(self) -> str: - raise NotImplementedError() - - @cached_property - def client_id(self) -> str: - return self.scope.get('client_id') or f'{self.scope['user'].id}/{get_random_string(8)}' - - @cached_property - def client_color(self) -> str: - return RandomColor(seed=get_random_string(8)).generate(luminosity='bright')[0] - - @database_sync_to_async - def check_permission(self, skip_on_recent_check=False, action=None, **kwargs): - # Skip permission check if it was done recently - if skip_on_recent_check and self.last_permission_check_time and self.last_permission_check_time + timedelta(seconds=60) >= timezone.now(): - return True - - # Check if session is still valid - session = self.scope.get('session') - if not session or not session.session_key or \ - session.expire_date < timezone.now() or \ - not session.exists(session.session_key): - return False - - # Check custom permissions - res = self.has_permission(action=action, **kwargs) - self.last_permission_check_time = timezone.now() - return res - - def has_permission(self, **kwargs): - return True - - @database_sync_to_async - def create_client_info(self): - CollabClientInfo.objects.create( - related_id=self.related_id, - user=self.scope['user'], - client_id=self.client_id, - client_color=self.client_color, - path=self.initial_path, - ) - - @database_sync_to_async - def delete_client_info(self): - CollabClientInfo.objects \ - .filter(client_id=self.client_id) \ - .delete() - - def filter_path(self, qs_or_obj): - return qs_or_obj - - def get_client_infos(self): - clients = CollabClientInfo.objects \ - .filter(related_id=self.related_id) \ - .select_related('user') - clients = self.filter_path(clients) - - return [{ - 'client_id': c.client_id, - 'client_color': c.client_color, - 'user': PentestUserSerializer(c.user).data, - 'path': c.path, - } for c in clients] - - async def get_initial_message(self): - return None - - async def get_connect_message(self): - return { - 'type': CollabEventType.CONNECT, - 'client_id': self.client_id, - 'path': self.initial_path, - 'client': { - 'client_id': self.client_id, - 'client_color': self.client_color, - 'user': PentestUserSerializer(self.scope['user']).data, - }, - } - - async def get_disconnect_message(self): - return { - 'type': CollabEventType.DISCONNECT, - 'client_id': self.client_id, - 'path': self.initial_path, - } - - async def connect(self): - if not await self.check_permission(action='connect'): - raise DenyConnection() - - await super().connect() - await self.create_client_info() - if initial_msg := await self.get_initial_message(): - await self.send_json(initial_msg) - - await self.channel_layer.group_add(self.group_name, self.channel_name) - if connect_msg := await self.get_connect_message(): - await self.send_colllab_event(connect_msg) - - async def disconnect(self, close_code): - await self.channel_layer.group_discard(self.group_name, self.channel_name) - await self.delete_client_info() - if disconnect_msg := await self.get_disconnect_message(): - await self.send_colllab_event(disconnect_msg) - await super().disconnect(close_code) - - async def send_colllab_event(self, event): - if not event: - return - elif isinstance(event, CollabEvent): - await self.channel_layer.group_send(self.group_name, { - 'type': 'collab_event', - 'id': str(event.id), - 'path': event.path, - }) - else: - await self.channel_layer.group_send(self.group_name, { - 'type': 'collab_event', - 'path': event.get('path'), - 'event': event, - }) - - async def collab_event(self, event): - if not self.filter_path(event): - return - - if event.get('id'): - @database_sync_to_async - def get_collab_event(id): - return CollabEvent.objects.get(id=id) - - # Retry fetching event from DB: DB transactions can cause the channels event to arrive before event data is commited to the DB - collab_event = await aretry(lambda: get_collab_event(event['id']), retry_for=CollabEvent.DoesNotExist) - await self.send_json({ - 'type': collab_event.type, - 'path': collab_event.path, - 'client_id': collab_event.client_id, - 'version': collab_event.version, - **collab_event.data, - }) - elif isinstance(event.get('event'), dict): - await self.send_json(event['event']) - - -class NotesConsumerBase(WebsocketConsumerBase): +class NotesConsumerBase(GenericCollabMixin, WebsocketConsumerBase): serializer_class = None initial_path = 'notes' @@ -287,23 +61,22 @@ def get_initial_message(self): } async def receive_json(self, content, **kwargs): - msg_type = content.get('type') - if msg_type == CollabEventType.UPDATE_KEY: - event = await self.collab_update_key(content) - await self.send_colllab_event(event) - elif msg_type == CollabEventType.UPDATE_TEXT: - event = await self.collab_update_text(content) - await self.send_colllab_event(event) - elif msg_type == CollabEventType.AWARENESS: - event = await self.collab_update_awareness(content) - await self.send_colllab_event(event) - else: - raise ValueError(f'Invalid message type: {msg_type}') + event = None + match content.get('type'): + case CollabEventType.UPDATE_KEY: + event = await self.collab_update_key(content) + case CollabEventType.UPDATE_TEXT: + event = await self.collab_update_text(content) + case CollabEventType.AWARENESS: + event = await self.collab_update_awareness(content) + case _: + raise ValueError(f'Invalid message type: {content.get("type")}') + await self.send_colllab_event(event) def filter_path(self, obj_or_qs): if isinstance(obj_or_qs, models.QuerySet): return obj_or_qs.filter(path__startswith='notes') - elif isinstance(obj_or_qs, dict) and obj_or_qs.get('path', '').startswith('notes'): + elif isinstance(obj_or_qs, dict) and (obj_or_qs.get('path') or '').startswith('notes'): return obj_or_qs return None @@ -319,126 +92,27 @@ def get_note_for_update(self, path, valid_paths=None): .first() if not note: raise ValidationError('Invalid path: ID not found') - return note, path_parts[2] - - @database_sync_to_async - @transaction.atomic() - def collab_update_key(self, content): - # Validate path and get note - valid_paths = {k for k, f in self.get_serializer().fields.items() if not f.read_only} - {'title', 'text'} - note, key = self.get_note_for_update(path=content.get('path'), valid_paths=valid_paths) - - # Update in DB - serializer = self.get_serializer(instance=note, data={key: content.get('value')}, partial=True) + return note, path_parts[2:], None + + def get_object_for_update(self, content): + match content.get('type'): + case CollabEventType.UPDATE_KEY: + valid_paths = {k for k, f in self.get_serializer().fields.items() if not f.read_only} - {'title', 'text'} + case CollabEventType.UPDATE_TEXT: + valid_paths=['title', 'text'] + case _: + raise ValidationError('Invalid collab event type') + return self.get_note_for_update(path=content.get('path'), valid_paths=valid_paths) + + def perform_update_key(self, obj, path, definition, value): + serializer = self.get_serializer(instance=obj, data={path[0]: value}, partial=True) serializer.is_valid(raise_exception=True) - with collab_context(prevent_events=True): - note = serializer.save() - - return CollabEvent.objects.create( - related_id=self.related_id, - path=content['path'], - type=CollabEventType.UPDATE_KEY, - created=note.updated, - version=note.updated.timestamp(), - client_id=self.client_id, - data={ - 'value': content['value'], - }, - ) - - @database_sync_to_async - @transaction.atomic() - def collab_update_text(self, content): - # Validate path and get note - if not content.get('updates', []): - raise ValidationError('No updates') - note, key = self.get_note_for_update(path=content.get('path'), valid_paths=['title', 'text']) - - version = content['version'] - # TODO: reject updates for versions that are too old - # * check if version is too old and if there are updates in between - # * simple timestamp comparison is not enough, because when there were no updates in between, the version is still valid - # * checking version < note.version is not enough, because of concurrent updates (e.g. old version, update1 succeeds, update2 fails because of updated version) - - # Rebase updates - over_updates = CollabEvent.objects \ - .filter(related_id=self.related_id) \ - .filter(path=content['path']) \ - .filter(type=CollabEventType.UPDATE_TEXT) \ - .filter(version__gt=version) \ - .order_by('version') - updates, selection = rebase_updates( - updates=[Update.from_dict(u | {'client_id': self.client_id, 'version': version}) for u in content.get('updates', [])], - selection=EditorSelection.from_dict(content['selection']) if content.get('selection') else None, - over=list(itertools.chain(*[[ - Update.from_dict(u | {'client_id': e.client_id, 'version': version}) - for u in e.data.get('updates', [])] for e in over_updates])), - ) - if not updates: - raise ValidationError('No updates') - - # Update in DB - changes = updates[0].changes - for u in updates[1:]: - changes = changes.compose(u.changes) - setattr(note, key, changes.apply(getattr(note, key) or '')) - with collab_context(prevent_events=True): - note.save() - - # Update client info - CollabClientInfo.objects \ - .filter(client_id=self.client_id) \ - .update(path=content['path']) - - # Store OT event in DB - return CollabEvent.objects.create( - related_id=self.related_id, - path=content['path'], - type=CollabEventType.UPDATE_TEXT, - created=note.updated, - version=note.updated.timestamp(), - client_id=self.client_id, - data={ - 'updates': [u.to_dict() for u in updates], - **({'selection': selection.to_dict()} if selection else {}), - }, - ) - - @database_sync_to_async - def collab_update_awareness(self, content): - path = content.get('path') or 'notes' - - version = content['version'] - - selection = None - if content.get('path') and content.get('selection'): - over_events = CollabEvent.objects \ - .filter(related_id=self.related_id) \ - .filter(path=path) \ - .filter(type=CollabEventType.UPDATE_TEXT) \ - .filter(version__gt=version) \ - .order_by('version') - over_updates = list(itertools.chain(*[[ - Update.from_dict(u | {'client_id': self.client_id, 'version': version}) - for u in e.data.get('updates', [])] for e in over_events])) - version = max([e.version for e in over_updates] + [version]) - - selection = EditorSelection.from_dict(content['selection']) - for u in over_updates: - selection = selection.map(u.changes) - - # Update client info - CollabClientInfo.objects \ - .filter(client_id=self.client_id) \ - .update(path=path) - - return { - 'type': CollabEventType.AWARENESS, - 'path': path, - 'client_id': self.client_id, - **({'selection': selection.to_dict()} if selection else {}), - } + return serializer.save() + def perform_update_text(self, obj, path, definition, changes): + setattr(obj, path[0], changes.apply(getattr(obj, path[0]) or '')) + obj.save() + return obj class ProjectNotesConsumer(NotesConsumerBase): @@ -502,7 +176,7 @@ def get_notes_queryset(self): .select_related('parent') -class ProjectReportingConsumer(WebsocketConsumerBase): +class ProjectReportingConsumer(GenericCollabMixin, WebsocketConsumerBase): @property def related_id(self): return self.scope['url_route']['kwargs']['project_pk'] @@ -569,26 +243,23 @@ def get_initial_message(self): } async def receive_json(self, content, **kwargs): - msg_type = content.get('type') - if msg_type == CollabEventType.UPDATE_KEY: - event = await self.collab_update_key(content) - await self.send_colllab_event(event) - elif msg_type == CollabEventType.UPDATE_TEXT: - event = await self.collab_update_text(content) - await self.send_colllab_event(event) - elif msg_type == CollabEventType.CREATE: - event = await self.collab_create(content) - await self.send_colllab_event(event) - elif msg_type == CollabEventType.DELETE: - event = await self.collab_delete(content) - await self.send_colllab_event(event) - elif msg_type == CollabEventType.AWARENESS: - event = await self.collab_update_awareness(content) - await self.send_colllab_event(event) - else: - raise ValueError(f'Invalid message type: {msg_type}') - - def get_object_for_update(self, path): + event = None + match content.get('type'): + case CollabEventType.UPDATE_KEY: + event = await self.collab_update_key(content) + case CollabEventType.UPDATE_TEXT: + event = await self.collab_update_text(content) + case CollabEventType.AWARENESS: + event = await self.collab_update_awareness(content) + case CollabEventType.CREATE: + event = await self.collab_create(content) + case CollabEventType.DELETE: + event = await self.collab_delete(content) + case _: + raise ValueError(f'Invalid message type: {content.get("type")}') + await self.send_colllab_event(event) + + def _get_object_for_update(self, path): if not isinstance(path, str): raise ValidationError('Invalid path') path_parts = tuple(path.split('.')) @@ -625,105 +296,48 @@ def get_object_for_update(self, path): return obj, path_parts[2:], None - @database_sync_to_async - @transaction.atomic() - def collab_update_key(self, content): - # Validate path and get section/finding - obj, path, definition = self.get_object_for_update(content.get('path')) - if definition and definition.type in [FieldDataType.MARKDOWN, FieldDataType.STRING]: - raise ValidationError('collab.update_key is not supported for text fields. Use collab.update_text instead.') + def get_object_for_update(self, content): + obj, path, definition = self._get_object_for_update(content.get('path')) + match content.get('type'): + case CollabEventType.UPDATE_TEXT: + if not definition or definition.type not in [FieldDataType.MARKDOWN, FieldDataType.STRING]: + raise ValidationError('collab.update_text is not supported for non-text fields. Use collab.update_key instead.') + case CollabEventType.UPDATE_KEY: + if definition and definition.type in [FieldDataType.MARKDOWN, FieldDataType.STRING]: + raise ValidationError('collab.update_key is not supported for text fields. Use collab.update_text instead.') + case CollabEventType.CREATE: + if not definition or definition.type != FieldDataType.LIST: + raise ValidationError('collab.create is only supported for list fields') + case CollabEventType.DELETE: + if not definition: + raise ValidationError('collab.delete is only supported for list fields') + case _: + raise ValidationError('Invalid collab event type') + return obj, path, definition + + def perform_update_text(self, obj, path, _definition, changes): + updated_data = obj.data + set_value_at_path(updated_data, path[1:], changes.apply(get_value_at_path(updated_data, path[1:]) or '')) + obj.update_data(updated_data) + obj.save() + return obj + def perform_update_key(self, obj, path, definition, value): # Update data in DB if definition: updated_data = obj.data - set_value_at_path(updated_data, path[1:], content.get('value')) + set_value_at_path(updated_data, path[1:], value) serializer_data = {'data': updated_data} else: - serializer_data = {path[0]: content.get('value')} + serializer_data = {path[0]: value} serializer = (ReportSectionSerializer if isinstance(obj, ReportSection) else PentestFindingSerializer)(instance=obj, data=serializer_data, partial=True) serializer.is_valid(raise_exception=True) - with collab_context(prevent_events=True): - obj = serializer.save() - - if content.get('update_awareness', False): - CollabClientInfo.objects \ - .filter(client_id=self.client_id) \ - .update(path=content['path']) - - return CollabEvent.objects.create( - related_id=self.related_id, - path=content['path'], - type=CollabEventType.UPDATE_KEY, - created=obj.updated, - version=obj.updated.timestamp(), - client_id=self.client_id, - data={ - 'value': content['value'], - }, - ) - - @database_sync_to_async - @transaction.atomic() - def collab_update_text(self, content): - obj, path, definition = self.get_object_for_update(content.get('path')) - if not definition or definition.type not in [FieldDataType.MARKDOWN, FieldDataType.STRING]: - raise ValidationError('collab.update_text is not supported for non-text fields. Use collab.update_key instead.') - - version = content['version'] - # TODO: reject updates for versions that are too old - - # Rebase updates - over_updates = CollabEvent.objects \ - .filter(related_id=self.related_id) \ - .filter(path=content['path']) \ - .filter(type=CollabEventType.UPDATE_TEXT) \ - .filter(version__gt=version) \ - .order_by('version') - updates, selection = rebase_updates( - updates=[Update.from_dict(u | {'client_id': self.client_id, 'version': version}) for u in content.get('updates', [])], - selection=EditorSelection.from_dict(content['selection']) if content.get('selection') else None, - over=list(itertools.chain(*[[ - Update.from_dict(u | {'client_id': e.client_id, 'version': version}) - for u in e.data.get('updates', [])] for e in over_updates])), - ) - if not updates: - raise ValidationError('No updates') - - # Update in DB - changes = updates[0].changes - for u in updates[1:]: - changes = changes.compose(u.changes) - updated_data = obj.data - set_value_at_path(updated_data, path[1:], changes.apply(get_value_at_path(updated_data, path[1:]) or '')) - obj.update_data(updated_data) - with collab_context(prevent_events=True): - obj.save() - - # Update client info - CollabClientInfo.objects \ - .filter(client_id=self.client_id) \ - .update(path=content['path']) - - # Store OT event in DB - return CollabEvent.objects.create( - related_id=self.related_id, - path=content['path'], - type=CollabEventType.UPDATE_TEXT, - created=obj.updated, - version=obj.updated.timestamp(), - client_id=self.client_id, - data={ - 'updates': [u.to_dict() for u in updates], - **({'selection': selection.to_dict()} if selection else {}), - }, - ) + return serializer.save() @database_sync_to_async @transaction.atomic() def collab_create(self, content): - obj, path, definition = self.get_object_for_update(content.get('path')) - if not definition or definition.type != FieldDataType.LIST: - raise ValidationError('collab.create is only supported for list fields') + obj, path, _ = self.get_object_for_update(content) # Update DB updated_data = obj.data @@ -750,14 +364,12 @@ def collab_create(self, content): @database_sync_to_async @transaction.atomic() def collab_delete(self, content): - obj, path, definition = self.get_object_for_update(content.get('path')) - if not definition: - raise ValidationError('collab.delete is only supported for fields') + obj, path, _ = self.get_object_for_update(content) updated_data = obj.data lst = get_value_at_path(updated_data, path[1:-1]) if not isinstance(lst, list): - raise ValidationError('collab.delete is only supported for fields') + raise ValidationError('collab.delete is only supported for list fields') index = int(path[-1][1:-1] if path[-1].startswith('[') and path[-1].endswith(']') else path[-1]) if not (0 <= index < len(lst)): raise ValidationError('Invalid list index') @@ -776,41 +388,6 @@ def collab_delete(self, content): client_id=self.client_id, ) - @database_sync_to_async - def collab_update_awareness(self, content): - path = content.get('path') - - version = content['version'] - - selection = None - if content.get('path') and content.get('selection'): - over_events = CollabEvent.objects \ - .filter(related_id=self.related_id) \ - .filter(path=path) \ - .filter(type=CollabEventType.UPDATE_TEXT) \ - .filter(version__gt=version) \ - .order_by('version') - over_updates = list(itertools.chain(*[[ - Update.from_dict(u | {'client_id': self.client_id, 'version': version}) - for u in e.data.get('updates', [])] for e in over_events])) - version = max([e.version for e in over_updates] + [version]) - - selection = EditorSelection.from_dict(content['selection']) - for u in over_updates: - selection = selection.map(u.changes) - - # Update client info - CollabClientInfo.objects \ - .filter(client_id=self.client_id) \ - .update(path=path) - - return { - 'type': CollabEventType.AWARENESS, - 'path': path, - 'client_id': self.client_id, - **({'selection': selection.to_dict()} if selection else {}), - } - def send_collab_event_project(event: CollabEvent): group_name = f'project_{event.related_id}' diff --git a/api/src/reportcreator_api/pentests/signals.py b/api/src/reportcreator_api/pentests/signals.py index 6572466c6..27870df0a 100644 --- a/api/src/reportcreator_api/pentests/signals.py +++ b/api/src/reportcreator_api/pentests/signals.py @@ -7,6 +7,7 @@ from django.utils import timezone from simple_history.signals import pre_create_historical_record +from reportcreator_api.pentests.collab.text_transformations import ChangeSet, Update from reportcreator_api.pentests.customfields.types import FieldDataType, parse_field_definition from reportcreator_api.pentests.customfields.utils import ( HandleUndefinedFieldsOptions, @@ -409,14 +410,16 @@ def collab_updated(sender, instance, created, *args, **kwargs): 'path': f'notes.{getattr(instance, "note_id", None)}', 'related_id': getattr(instance, 'project_id', None), 'serializer': ProjectNotebookPageSerializer, - 'update_keys': ['title', 'text', 'checked', 'icon_emoji', 'assignee_id'], + 'update_keys': ['checked', 'icon_emoji', 'assignee_id'], + 'update_text': ['title', 'text'], }, UserNotebookPage: { 'send': send_collab_event_user, 'path': f'notes.{getattr(instance, "note_id", None)}', 'related_id': getattr(instance, 'user_id', None), 'serializer': UserNotebookPageSerializer, - 'update_keys': ['title', 'text', 'checked', 'icon_emoji'], + 'update_keys': ['checked', 'icon_emoji'], + 'update_text': ['title', 'text'], }, ReportSection: { 'send': send_collab_event_project, @@ -455,7 +458,8 @@ def collab_updated(sender, instance, created, *args, **kwargs): sorted_instances = instance.user.notes.select_related('parent').all() UserNotebookPageSortListSerializer(sorted_instances, context={'user': instance.user}) \ .send_collab_event(sorted_instances) - elif update_keys := set(instance.changed_fields).intersection(sender_options['update_keys']): + elif update_keys := set(instance.changed_fields).intersection(sender_options['update_keys'] + sender_options.get('update_text', [])): + update_text = set(sender_options.get('update_text', [])) if 'custom_fields' in update_keys: update_keys.discard('custom_fields') updated_lists = set() @@ -479,21 +483,41 @@ def collab_updated(sender, instance, created, *args, **kwargs): # Remove parent: only update leaf nodes update_keys.discard('.'.join(('data',) + path[:-1])) + # Text updates + if new_values[path][2].type in [FieldDataType.MARKDOWN, FieldDataType.STRING] and \ + new_values[path][2].type == old_values[path][2].type and \ + isinstance(new_value, str) and isinstance(old_value, str): + update_text.add(path_str) + for k in update_keys: serialized_data = sender_options['serializer'](instance).data if k not in serialized_data and '.' not in k and k.endswith('_id'): k = k[:-3] - sender_options['send'](CollabEvent.objects.create( - related_id=sender_options['related_id'], - type=CollabEventType.UPDATE_KEY, - path=f"{sender_options['path']}.{k}", - created=instance.updated, - version=instance.updated.timestamp(), - data={ - 'value': get_value_at_path(instance.data, k.split('.')[1:]) if k.startswith('data.') else serialized_data[k], - }, - )) + if k in update_text: + text_before = get_value_at_path(instance.initial['custom_fields'], k.split('.')[1:]) if k.startswith('data.') else instance.initial[k] + text_after = get_value_at_path(instance.data, k.split('.')[1:]) if k.startswith('data.') else serialized_data[k] + sender_options['send'](CollabEvent.objects.create( + related_id=sender_options['related_id'], + type=CollabEventType.UPDATE_TEXT, + path=f"{sender_options['path']}.{k}", + created=instance.updated, + version=instance.updated.timestamp(), + data={ + 'updates': [Update(client_id=None, version=None, changes=ChangeSet.from_diff(text_before, text_after)).to_dict()], + }, + )) + else: + sender_options['send'](CollabEvent.objects.create( + related_id=sender_options['related_id'], + type=CollabEventType.UPDATE_KEY, + path=f"{sender_options['path']}.{k}", + created=instance.updated, + version=instance.updated.timestamp(), + data={ + 'value': get_value_at_path(instance.data, k.split('.')[1:]) if k.startswith('data.') else serialized_data[k], + }, + )) @receiver(signals.post_delete, sender=ProjectNotebookPage) diff --git a/api/src/reportcreator_api/tests/test_collab.py b/api/src/reportcreator_api/tests/test_collab.py index c1798dd3e..40ffe5231 100644 --- a/api/src/reportcreator_api/tests/test_collab.py +++ b/api/src/reportcreator_api/tests/test_collab.py @@ -17,6 +17,14 @@ from reportcreator_api.archive.import_export import export_notes from reportcreator_api.conf.asgi import application +from reportcreator_api.pentests.collab.text_transformations import ( + ChangeSet, + CollabStr, + EditorSelection, + SelectionRange, + Update, + rebase_updates, +) from reportcreator_api.pentests.customfields.utils import ( ensure_defined_structure, get_value_at_path, @@ -31,14 +39,6 @@ ReviewStatus, ) from reportcreator_api.tests.mock import api_client, create_project, create_project_type, create_user, mock_time -from reportcreator_api.utils.text_transformations import ( - ChangeSet, - CollabStr, - EditorSelection, - SelectionRange, - Update, - rebase_updates, -) from reportcreator_api.utils.utils import copy_keys @@ -173,6 +173,43 @@ def test_selection_mapping(self, selection, change, expected): actual = selection.map(change) assert actual == expected + @pytest.mark.parametrize(('text_before', 'text_after'), [ + ('line1\nline2\n', 'line1\nline2\n'), # same text + ('', 'new text'), + ('old text', ''), + ('old text\nline2', 'completely replaced\nwith new content'), + ('line1\nline2\n', 'line1\ninserted\nline2\n'), + ('line1\ndeleted\nline2', 'line1\nline2\n'), + ('line1\nsome characters changed\nline2\n', 'line1\nsome char___ers changed\nline2\n'), + # newline handling + ('line1\nline2\n', 'line1\nchanged\n'), + ('line1\nline2', 'line1\nchanged'), + ('line1\nline2', 'line1\nchanged\n'), + ('line1\nline2\n', 'line1\nchanged'), + ('line1\nline2\n', 'line1\n\n\nchanged\n'), + # unicode handling + ('line1\nline2\n', 'line1 🤦🏼‍♂️ text\nline2\n'), + ('line1 🤦🏼‍♂️ text\nline2\n', 'line1 🤦🏼‍♂️ text\nline2\nline 3 🤦🏼‍♂️'), + ('line1 🤦🏼‍♂️ text\nline2\n', 'line1 text\nline2\n'), + ('line1 text\nline2\n', 'line1 🤦🏼‍♂️\nline2\n'), + ('line1 🤦🏼‍♂️\nline2\n', 'line1 🤷\nline2\n'), + ('line1 🤦🏼‍♂️\nline2\n', 'line1 🤦🏿‍♀️\nline2\n'), + # multiple changes + ('some example text', 's__e ex__ple t__t'), + ('some example text', 's_e new example'), + ('some example text', 's_e text new'), + ]) + def test_diff_to_changeset(self, text_before, text_after): + # Forward change + c1 = ChangeSet.from_diff(text_before, text_after) + assert c1.apply(text_before) == text_after + c1.to_dict() + + # Reverse change + c2 = ChangeSet.from_diff(text_after, text_before) + assert c2.apply(text_after) == text_before + c2.to_dict() + @sync_to_async def create_session(user): @@ -442,7 +479,7 @@ async def setUp(self): def setup_db(): self.user1 = create_user() self.user2 = create_user() - self.project = create_project(members=[self.user1, self.user2], notes_kwargs=[{'checked': None, 'icon_emoji': None, 'text': 'ABC'}]) + self.project = create_project(members=[self.user1, self.user2], notes_kwargs=[{'checked': None, 'icon_emoji': None, 'title': 'ABC', 'text': 'ABC'}]) self.note = self.project.notes.all()[0] self.note_path_prefix = f'notes.{self.note.note_id}' self.api_client1 = api_client(self.user1) @@ -511,20 +548,24 @@ async def test_delete_sync(self): async def test_update_key_sync(self): await sync_to_async(self.api_client1.patch)( path=reverse('projectnotebookpage-detail', kwargs={'project_pk': self.project.id, 'id': self.note.note_id}), - data={'checked': True, 'title': 'updated'}) + data={'checked': True, 'title': 'updated', 'text': 'ABCDEF'}) r1_1 = await self.client1.receive_json_from() r1_2 = await self.client1.receive_json_from() - res1 = {r1_1['path']: r1_1, r1_2['path']: r1_2} + r1_3 = await self.client1.receive_json_from() + res1 = {r1_1['path']: r1_1, r1_2['path']: r1_2, r1_3['path']: r1_3} r2_1 = await self.client2.receive_json_from() r2_2 = await self.client2.receive_json_from() - res2 = {r2_1['path']: r2_1, r2_2['path']: r2_2} + r2_3 = await self.client2.receive_json_from() + res2 = {r2_1['path']: r2_1, r2_2['path']: r2_2, r2_3['path']: r2_3} - for k, v in ({'type': CollabEventType.UPDATE_KEY, 'path': self.note_path_prefix + '.title', 'value': 'updated', 'client_id': None}).items(): - assert res1[self.note_path_prefix + '.title'][k] == res2[self.note_path_prefix + '.title'][k] == v for k, v in ({'type': CollabEventType.UPDATE_KEY, 'path': self.note_path_prefix + '.checked', 'value': True, 'client_id': None}).items(): assert res1[self.note_path_prefix + '.checked'][k] == res2[self.note_path_prefix + '.checked'][k] == v + for k, v in ({'type': CollabEventType.UPDATE_TEXT, 'path': self.note_path_prefix + '.title', 'updates': [{'changes': [[3, 'updated']]}], 'client_id': None}).items(): + assert res1[self.note_path_prefix + '.title'][k] == res2[self.note_path_prefix + '.title'][k] == v + for k, v in ({'type': CollabEventType.UPDATE_TEXT, 'path': self.note_path_prefix + '.text', 'updates': [{'changes': [3, [0, 'DEF']]}], 'client_id': None}).items(): + assert res1[self.note_path_prefix + '.text'][k] == res2[self.note_path_prefix + '.text'][k] == v async def test_sort_sync(self): res = await sync_to_async(self.api_client1.post)( @@ -679,12 +720,12 @@ async def test_update_key(self, obj_type, path, value): value_h = getattr(obj_h, path_parts[0]) if len(path_parts) == 1 else get_value_at_path(obj_h.custom_fields, path_parts[1:]) assert value_h == value_db - @pytest.mark.parametrize(('obj_type', 'path'), [(a,) + b for a, b in itertools.product(['finding', 'section'], [ - ('data.field_string',), - ('data.field_markdown',), - ('data.field_list.[0]',), - ('data.field_list_objects.[0].field_string',), - ])]) + @pytest.mark.parametrize(('obj_type', 'path'), list(itertools.product(['finding', 'section'], [ + 'data.field_string', + 'data.field_markdown', + 'data.field_list.[0]', + 'data.field_list_objects.[0].field_string', + ]))) async def test_update_text(self, obj_type, path): if obj_type == 'section': obj = self.section @@ -751,6 +792,29 @@ async def test_update_key_sync(self, obj_type, path, value): await self.assert_event({'type': CollabEventType.UPDATE_KEY, 'path': f'{path_prefix}.{path}', 'value': value, 'client_id': None}) assert await self.client1.receive_nothing() + @pytest.mark.parametrize(('obj_type', 'path'), list(itertools.product(['finding', 'section'], [ + 'data.field_string', + 'data.field_markdown', + 'data.field_list.[0]', + 'data.field_list_objects.[0].field_string', + ]))) + async def test_update_text_sync(self, obj_type, path): + if obj_type == 'section': + obj = self.section + path_prefix = self.section_path_prefix + elif obj_type == 'finding': + obj = self.finding + path_prefix = self.finding_path_prefix + + updated_data = obj.data + set_value_at_path(updated_data, path.split('.')[1:], 'ABCDEF') + obj.update_data(updated_data) + await obj.asave() + + # Websocket messages sent to clients + await self.assert_event({'type': CollabEventType.UPDATE_TEXT, 'path': f'{path_prefix}.{path}', 'updates': [{'changes': [3, [0, 'DEF']]}], 'client_id': None}) + assert await self.client1.receive_nothing() + async def test_sort_findings_sync(self): res = await sync_to_async(self.api_client1.post)( path=reverse('finding-sort', kwargs={'project_pk': self.project.id}),