From 6c8363e24b786d41e2d8fd0d2d2398bd38719ef4 Mon Sep 17 00:00:00 2001 From: salam dalloul Date: Thu, 3 Oct 2024 23:59:40 +0200 Subject: [PATCH 1/7] #292 Improve Population Diagram styling --- .../ProofingTab/GraphDiagram/GraphDiagram.tsx | 169 +++++++++++------- 1 file changed, 106 insertions(+), 63 deletions(-) diff --git a/frontend/src/components/ProofingTab/GraphDiagram/GraphDiagram.tsx b/frontend/src/components/ProofingTab/GraphDiagram/GraphDiagram.tsx index 9447fe0f..2a22cfdf 100644 --- a/frontend/src/components/ProofingTab/GraphDiagram/GraphDiagram.tsx +++ b/frontend/src/components/ProofingTab/GraphDiagram/GraphDiagram.tsx @@ -1,10 +1,10 @@ -import React, {useEffect, useRef, useState} from "react"; +import React, {useEffect, useLayoutEffect, useRef, useState} from "react"; import InfoMenu from "./InfoMenu"; import NavigationMenu from "./NavigationMenu"; import createEngine, { - BasePositionModelOptions, + BasePositionModelOptions, DagreEngine, DefaultLinkModel, - DiagramModel, + DiagramModel, PathFindingLinkFactory, } from '@projectstorm/react-diagrams'; import {CanvasWidget} from '@projectstorm/react-canvas-core'; import {CustomNodeModel} from "./Models/CustomNodeModel"; @@ -16,8 +16,7 @@ import { TypeC11Enum, ViaSerializerDetails } from "../../../apiclient/backend"; - - +import {DiagramEngine} from "@projectstorm/react-diagrams-core"; export enum NodeTypes { Origin = 'Origin', Via = 'Via', @@ -78,42 +77,42 @@ const createLink = (sourceNode: CustomNodeModel, targetNode: CustomNodeModel, so }; const processData = ( - origins: AnatomicalEntity[] | undefined, - vias: ViaSerializerDetails[] | undefined, - destinations: DestinationSerializerDetails[] | undefined, - forward_connection: any[], + origins: AnatomicalEntity[] | undefined, + vias: ViaSerializerDetails[] | undefined, + destinations: DestinationSerializerDetails[] | undefined, + forward_connection: any[], ): { nodes: CustomNodeModel[], links: DefaultLinkModel[] } => { const nodes: CustomNodeModel[] = []; const links: DefaultLinkModel[] = []; - + const nodeMap = new Map(); - + const yStart = 50 const yIncrement = 250; // Vertical spacing const xIncrement = 250; // Horizontal spacing let xOrigin = 100 - + origins?.forEach(origin => { const id = getId(NodeTypes.Origin, origin) const name = origin.simple_entity !== null ? origin.simple_entity.name : origin.region_layer.region.name + '(' + origin.region_layer.layer.name + ')'; const ontology_uri = origin.simple_entity !== null ? origin.simple_entity.ontology_uri : origin.region_layer.region.ontology_uri + ', ' + origin.region_layer.layer.ontology_uri; const fws: never[] = [] const originNode = new CustomNodeModel( - NodeTypes.Origin, - name, - ontology_uri, - { - forward_connection: fws, - to: [], - } + NodeTypes.Origin, + name, + ontology_uri, + { + forward_connection: fws, + to: [], + } ); originNode.setPosition(xOrigin, yStart); nodes.push(originNode); nodeMap.set(id, originNode); xOrigin += xIncrement; }); - - + + vias?.forEach((via) => { const layerIndex = via.order + 1 let xVia = 120 @@ -124,21 +123,21 @@ const processData = ( const ontology_uri = entity.simple_entity !== null ? entity.simple_entity.ontology_uri : entity.region_layer.region.ontology_uri + ', ' + entity.region_layer.layer.ontology_uri; const fws: never[] = [] const viaNode = new CustomNodeModel( - NodeTypes.Via, - name, - ontology_uri, - { - forward_connection: fws, - from: [], - to: [], - anatomicalType: via?.type ? ViaTypeMapping[via.type] : '' - } + NodeTypes.Via, + name, + ontology_uri, + { + forward_connection: fws, + from: [], + to: [], + anatomicalType: via?.type ? ViaTypeMapping[via.type] : '' + } ); viaNode.setPosition(xVia, yVia); nodes.push(viaNode); nodeMap.set(id, viaNode); xVia += xIncrement - + via.from_entities.forEach(fromEntity => { const sourceNode = findNodeForEntity(fromEntity, nodeMap, layerIndex - 1); if (sourceNode) { @@ -155,12 +154,12 @@ const processData = ( }); yVia += yIncrement; }); - - + + const yDestination = yIncrement * ((vias?.length || 1) + 1) + yStart let xDestination = 115 - - + + // Process Destinations destinations?.forEach(destination => { destination.anatomical_entities.forEach(entity => { @@ -174,14 +173,14 @@ const processData = ( return false; }); const destinationNode = new CustomNodeModel( - NodeTypes.Destination, - name, - ontology_uri, - { - forward_connection: fws, - from: [], - anatomicalType: destination?.type ? DestinationTypeMapping[destination.type] : '', - } + NodeTypes.Destination, + name, + ontology_uri, + { + forward_connection: fws, + from: [], + anatomicalType: destination?.type ? DestinationTypeMapping[destination.type] : '', + } ); destinationNode.setPosition(xDestination, yDestination); nodes.push(destinationNode); @@ -202,73 +201,117 @@ const processData = ( }); }); }); - + return {nodes, links}; }; + +function genDagreEngine() { + return new DagreEngine({ + graph: { + rankdir: 'TB', + ranksep: 300, + nodesep: 250, + marginx: 50, + marginy: 50 + }, + }); +} +function reroute(engine: DiagramEngine) { + engine.getLinkFactories().getFactory(PathFindingLinkFactory.NAME).calculateRoutingMatrix(); +} +function autoDistribute(engine: DiagramEngine) { + const model = engine.getModel(); + + // Ensure model and nodes exist before proceeding + if (!model || model.getNodes().length === 0) { + return; + } + + const dagreEngine = genDagreEngine(); + dagreEngine.redistribute(model); + + reroute(engine); + engine.repaintCanvas(); +} const GraphDiagram: React.FC = ({origins, vias, destinations, forward_connection = []}) => { const [engine] = useState(() => createEngine()); const [modelUpdated, setModelUpdated] = useState(false) const [modelFitted, setModelFitted] = useState(false) const containerRef = useRef(null); - + // This effect runs once to set up the engine useEffect(() => { engine.getNodeFactories().registerFactory(new CustomNodeFactory()); }, [engine]); - - + + // This effect runs whenever origins, vias, or destinations change useEffect(() => { const {nodes, links} = processData(origins, vias, destinations, forward_connection); - + const model = new DiagramModel(); model.addAll(...nodes, ...links); - + engine.setModel(model); + // engine.getModel().setLocked(true) setModelUpdated(true) }, [origins, vias, destinations, engine, forward_connection]); - + // This effect prevents the default scroll and touchmove behavior useEffect(() => { const currentContainer = containerRef.current; - + if (modelUpdated && currentContainer) { const disableScroll = (event: Event) => { event.stopPropagation(); }; - + currentContainer.addEventListener('wheel', disableScroll, {passive: false}); currentContainer.addEventListener('touchmove', disableScroll, {passive: false}); - + return () => { currentContainer?.removeEventListener('wheel', disableScroll); currentContainer?.removeEventListener('touchmove', disableScroll); }; } + }, [modelUpdated]); - + useEffect(() => { if (modelUpdated && !modelFitted) { // TODO: for unknown reason at the moment if I call zoomToFit too early breaks the graph // To fix later in the next contract. + setTimeout(() => { engine.zoomToFit(); }, 1000); setModelFitted(true); } }, [modelUpdated, modelFitted, engine]); - + + useLayoutEffect(() => { + autoDistribute(engine); + }, [engine, modelUpdated, destinations, vias, origins]); + + useEffect(() => { + const currentContainer = containerRef.current; + + if (modelUpdated && currentContainer) { + autoDistribute(engine); + } + }, [engine, modelUpdated, destinations, vias, origins]); + return ( - modelUpdated ? ( -
- - - -
) - : null + modelUpdated ? ( +
+ + + +
) + : null ); } -export default GraphDiagram; +export default GraphDiagram; \ No newline at end of file From 105d3d64c93f35201201a72b113af82d7d516433 Mon Sep 17 00:00:00 2001 From: salam dalloul Date: Fri, 4 Oct 2024 00:01:54 +0200 Subject: [PATCH 2/7] #292 Improve Population Diagram styling --- .../src/components/ProofingTab/GraphDiagram/GraphDiagram.tsx | 1 - 1 file changed, 1 deletion(-) diff --git a/frontend/src/components/ProofingTab/GraphDiagram/GraphDiagram.tsx b/frontend/src/components/ProofingTab/GraphDiagram/GraphDiagram.tsx index 2a22cfdf..5d19878b 100644 --- a/frontend/src/components/ProofingTab/GraphDiagram/GraphDiagram.tsx +++ b/frontend/src/components/ProofingTab/GraphDiagram/GraphDiagram.tsx @@ -223,7 +223,6 @@ function reroute(engine: DiagramEngine) { function autoDistribute(engine: DiagramEngine) { const model = engine.getModel(); - // Ensure model and nodes exist before proceeding if (!model || model.getNodes().length === 0) { return; } From 42bc5376d3d9e867fa528b5cce98f635140d07e2 Mon Sep 17 00:00:00 2001 From: Salam Dalloul Date: Thu, 17 Oct 2024 13:21:50 +0200 Subject: [PATCH 3/7] resolve conflict --- backend/composer/api/serializers.py | 32 +- backend/composer/api/views.py | 35 +- .../0061_graphrenderingstate_and_more.py | 61 +++ backend/composer/models.py | 18 +- backend/composer/services/export_services.py | 298 +++++++------ backend/composer/signals.py | 6 +- backend/composer/templates/admin/index.html | 14 +- frontend/package.json | 2 +- frontend/src/apiclient/backend/api.ts | 31 ++ .../src/components/Forms/StatementForm.tsx | 59 +-- .../ProofingTab/GraphDiagram/GraphDiagram.tsx | 417 ++++++------------ .../GraphDiagram/Models/CustomNodeModel.tsx | 28 ++ .../GraphDiagram/NavigationMenu.tsx | 160 ++++--- .../components/ProofingTab/StatementChart.tsx | 4 +- frontend/src/services/GraphDiagramService.ts | 329 ++++++++++++++ frontend/src/services/StatementService.ts | 7 +- openapi/openapi.yaml | 20 + 17 files changed, 988 insertions(+), 533 deletions(-) create mode 100644 backend/composer/migrations/0061_graphrenderingstate_and_more.py create mode 100644 frontend/src/services/GraphDiagramService.ts diff --git a/backend/composer/api/serializers.py b/backend/composer/api/serializers.py index e97de954..817f0409 100644 --- a/backend/composer/api/serializers.py +++ b/backend/composer/api/serializers.py @@ -3,6 +3,7 @@ from django.contrib.auth.models import User from django.db.models import Q from django_fsm import FSMField +from drf_spectacular.types import OpenApiTypes from drf_writable_nested.mixins import UniqueFieldsMixin from drf_writable_nested.serializers import WritableNestedModelSerializer from rest_framework import serializers @@ -20,7 +21,7 @@ Sentence, Specie, Tag, - Via, Destination, AnatomicalEntityIntersection, Region, Layer, AnatomicalEntityMeta, + Via, Destination, AnatomicalEntityIntersection, Region, Layer, AnatomicalEntityMeta, GraphRenderingState, ) from ..services.connections_service import get_complete_from_entities_for_destination, \ get_complete_from_entities_for_via @@ -495,6 +496,18 @@ class Meta: read_only_fields = ("state",) +class GraphStateSerializer(serializers.ModelSerializer): + class Meta: + model = GraphRenderingState + fields = ['serialized_graph'] + + def to_representation(self, instance): + representation = super().to_representation(instance) + return { + 'serialized_graph': representation['serialized_graph'], + } + + class ConnectivityStatementSerializer(BaseConnectivityStatementSerializer): """Connectivity Statement""" @@ -516,9 +529,10 @@ class ConnectivityStatementSerializer(BaseConnectivityStatementSerializer): ) available_transitions = serializers.SerializerMethodField() journey = serializers.SerializerMethodField() - entities_journey = serializers.SerializerMethodField() + entities_journey = serializers.SerializerMethodField() statement_preview = serializers.SerializerMethodField() errors = serializers.SerializerMethodField() + graph_rendering_state = GraphStateSerializer(required=False, allow_null=True) def get_available_transitions(self, instance) -> list[CSState]: request = self.context.get("request", None) @@ -529,7 +543,7 @@ def get_journey(self, instance): if 'journey' not in self.context: self.context['journey'] = instance.get_journey() return self.context['journey'] - + def get_entities_journey(self, instance): self.context['entities_journey'] = instance.get_entities_journey() return self.context['entities_journey'] @@ -608,12 +622,10 @@ def to_representation(self, instance): return representation def update(self, instance, validated_data): - # Remove 'vias' and 'destinations' from validated_data if they exist + # Remove 'via_set' and 'destinations' from validated_data if they exist validated_data.pop('via_set', None) validated_data.pop('destinations', None) - - # Call the super class's update method with the modified validated_data - return super(ConnectivityStatementSerializer, self).update(instance, validated_data) + return instance class Meta(BaseConnectivityStatementSerializer.Meta): fields = ( @@ -648,7 +660,8 @@ class Meta(BaseConnectivityStatementSerializer.Meta): "modified_date", "has_notes", "statement_preview", - "errors" + "errors", + "graph_rendering_state" ) @@ -690,7 +703,8 @@ class Meta: "modified_date", "has_notes", "statement_preview", - "errors" + "errors", + "graph_rendering_state" ) diff --git a/backend/composer/api/views.py b/backend/composer/api/views.py index 39235ea6..35db88c7 100644 --- a/backend/composer/api/views.py +++ b/backend/composer/api/views.py @@ -53,7 +53,7 @@ Tag, Via, Provenance, - Sex, Destination, + Sex, Destination, GraphRenderingState, ) @@ -336,6 +336,22 @@ def get_queryset(self): return ConnectivityStatement.objects.excluding_draft() return super().get_queryset() + def handle_graph_rendering_state(self, instance, graph_rendering_state_data, user): + if graph_rendering_state_data: + if hasattr(instance, 'graph_rendering_state') and instance.graph_rendering_state is not None: + # Update the existing graph state + instance.graph_rendering_state.serialized_graph = graph_rendering_state_data.get( + 'serialized_graph', instance.graph_rendering_state.serialized_graph) + instance.graph_rendering_state.saved_by = user + instance.graph_rendering_state.save() + else: + # Create a new graph state if none exists + GraphRenderingState.objects.create( + connectivity_statement=instance, + serialized_graph=graph_rendering_state_data.get('serialized_graph', {}), + saved_by=user + ) + @extend_schema( methods=['PUT'], request=ConnectivityStatementUpdateSerializer, @@ -343,12 +359,15 @@ def get_queryset(self): ) def update(self, request, *args, **kwargs): origin_ids = request.data.pop('origins', None) + graph_rendering_state_data = request.data.pop('graph_rendering_state', None) response = super().update(request, *args, **kwargs) - if origin_ids and response.status_code == status.HTTP_200_OK: + if response.status_code == status.HTTP_200_OK: instance = self.get_object() - instance.set_origins(origin_ids) + self.handle_graph_rendering_state(instance, graph_rendering_state_data, request.user) + if origin_ids: + instance.set_origins(origin_ids) return response @@ -358,7 +377,15 @@ def update(self, request, *args, **kwargs): responses={200: ConnectivityStatementSerializer} ) def partial_update(self, request, *args, **kwargs): - return super().partial_update(request, *args, **kwargs) + graph_rendering_state_data = request.data.pop('graph_rendering_state', None) + + response = super().partial_update(request, *args, **kwargs) + + if response.status_code == status.HTTP_200_OK: + instance = self.get_object() + self.handle_graph_rendering_state(instance, graph_rendering_state_data, request.user) + + return response @extend_schema(tags=["public"]) diff --git a/backend/composer/migrations/0061_graphrenderingstate_and_more.py b/backend/composer/migrations/0061_graphrenderingstate_and_more.py new file mode 100644 index 00000000..867963b0 --- /dev/null +++ b/backend/composer/migrations/0061_graphrenderingstate_and_more.py @@ -0,0 +1,61 @@ +# Generated by Django 4.1.4 on 2024-10-07 13:46 + +from django.conf import settings +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + dependencies = [ + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + ( + "composer", + "0060_anatomicalentityintersection_unique_layer_region_combination_and_more", + ), + ] + + operations = [ + migrations.CreateModel( + name="GraphRenderingState", + fields=[ + ( + "id", + models.BigAutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ("serialized_graph", models.JSONField()), + ("created_at", models.DateTimeField(auto_now_add=True)), + ("updated_at", models.DateTimeField(auto_now=True)), + ], + ), + migrations.AddConstraint( + model_name="connectivitystatement", + constraint=models.CheckConstraint( + check=models.Q(("projection__in", ["IPSI", "CONTRAT", "BI"])), + name="projection_valid", + ), + ), + migrations.AddField( + model_name="graphrenderingstate", + name="connectivity_statement", + field=models.OneToOneField( + on_delete=django.db.models.deletion.CASCADE, + related_name="graph_rendering_state", + to="composer.connectivitystatement", + ), + ), + migrations.AddField( + model_name="graphrenderingstate", + name="saved_by", + field=models.ForeignKey( + blank=True, + null=True, + on_delete=django.db.models.deletion.SET_NULL, + to=settings.AUTH_USER_MODEL, + ), + ), + ] diff --git a/backend/composer/models.py b/backend/composer/models.py index 309c5475..66a34fb4 100644 --- a/backend/composer/models.py +++ b/backend/composer/models.py @@ -4,7 +4,6 @@ from django.db.models.expressions import F from django.forms.widgets import Input as InputWidget from django_fsm import FSMField, transition -from django.core.exceptions import ValidationError from composer.services.state_services import ( ConnectivityStatementStateService, @@ -664,7 +663,7 @@ def get_previous_layer_entities(self, via_order): def get_journey(self): return compile_journey(self)['journey'] - + def get_entities_journey(self): entities_journey = compile_journey(self)['entities'] return entities_journey @@ -713,6 +712,20 @@ class Meta: ] +class GraphRenderingState(models.Model): + connectivity_statement = models.OneToOneField( + ConnectivityStatement, + on_delete=models.CASCADE, + related_name='graph_rendering_state', + ) + serialized_graph = models.JSONField() # Stores the serialized diagram model + created_at = models.DateTimeField(auto_now_add=True) + updated_at = models.DateTimeField(auto_now=True) + saved_by = models.ForeignKey( + User, on_delete=models.SET_NULL, null=True, blank=True + ) + + class AbstractConnectionLayer(models.Model): connectivity_statement = models.ForeignKey(ConnectivityStatement, on_delete=models.CASCADE) anatomical_entities = models.ManyToManyField(AnatomicalEntity, blank=True) @@ -774,6 +787,7 @@ def save(self, *args, **kwargs): if old_via.order != self.order: self._update_order_for_other_vias(old_via.order) self.from_entities.clear() + super(Via, self).save(*args, **kwargs) def _update_order_for_other_vias(self, old_order): diff --git a/backend/composer/services/export_services.py b/backend/composer/services/export_services.py index ffc7a96c..d90c6f9d 100644 --- a/backend/composer/services/export_services.py +++ b/backend/composer/services/export_services.py @@ -7,13 +7,21 @@ from django.contrib.auth.models import User from django.db import transaction -from django.db.models import Count, QuerySet +from django.db.models import Count, QuerySet, Prefetch from django.utils import timezone -from composer.enums import CSState -from composer.enums import NoteType, ExportRelationships, CircuitType, Laterality, MetricEntity, DestinationType, \ - ViaType, SentenceState, \ - Projection +from composer.enums import ( + CSState, + NoteType, + ExportRelationships, + CircuitType, + Laterality, + MetricEntity, + DestinationType, + ViaType, + SentenceState, + Projection, +) from composer.exceptions import UnexportableConnectivityStatement from composer.models import ( Tag, @@ -22,10 +30,15 @@ ExportMetrics, Sentence, Specie, - Via, AnatomicalEntity, Destination, + Via, + AnatomicalEntity, + Destination, + Note, +) +from composer.services.connections_service import ( + get_complete_from_entities_for_destination, + get_complete_from_entities_for_via, ) -from composer.services.connections_service import get_complete_from_entities_for_destination, \ - get_complete_from_entities_for_via from composer.services.filesystem_service import create_dir_if_not_exists from composer.services.state_services import ConnectivityStatementStateService @@ -61,16 +74,16 @@ class Row: def __init__( - self, - structure: str, - identifier: str, - relationship: str, - predicate: str, - curation_notes: str = "", - review_notes: str = "", - layer: str = "", - connected_from_names: str = "", - connected_from_uris: str = "" + self, + structure: str, + identifier: str, + relationship: str, + predicate: str, + curation_notes: str = "", + review_notes: str = "", + layer: str = "", + connected_from_names: str = "", + connected_from_uris: str = "", ): self.structure = structure self.identifier = identifier @@ -96,14 +109,13 @@ def get_nlp_id(cs: ConnectivityStatement, row: Row): def get_neuron_population_label(cs: ConnectivityStatement, row: Row): - return ' '.join(cs.get_journey()) + return " ".join(cs.get_journey()) def get_type(cs: ConnectivityStatement, row: Row): return cs.phenotype.name if cs.phenotype else "" - def get_structure(cs: ConnectivityStatement, row: Row): return row.structure @@ -133,7 +145,7 @@ def get_predicate(cs: ConnectivityStatement, row: Row): def get_observed_in_species(cs: ConnectivityStatement, row: Row): - return ", ".join([specie.name for specie in cs.species.all()]) + return ", ".join(specie.name for specie in cs.species.all()) def escape_newlines(value): @@ -141,21 +153,22 @@ def escape_newlines(value): def get_different_from_existing(cs: ConnectivityStatement, row: Row): - return escape_newlines( - "\n".join([note.note for note in cs.notes.filter(type=NoteType.DIFFERENT)]) - ) + different_notes = [ + note.note for note in cs.prefetched_notes if note.type == NoteType.DIFFERENT + ] + return escape_newlines("\n".join(different_notes)) def get_curation_notes(cs: ConnectivityStatement, row: Row): - return escape_newlines(row.curation_notes.replace("\\", "\\\\")) + return escape_newlines(row.curation_notes) def get_review_notes(cs: ConnectivityStatement, row: Row): - return escape_newlines(row.review_notes.replace("\\", "\\\\")) + return escape_newlines(row.review_notes) def get_reference(cs: ConnectivityStatement, row: Row): - return ", ".join([procenance.uri for procenance in cs.provenance_set.all()]) + return ", ".join(procenance.uri for procenance in cs.provenance_set.all()) def is_approved_by_sawg(cs: ConnectivityStatement, row: Row): @@ -171,12 +184,12 @@ def get_added_to_sckan_timestamp(cs: ConnectivityStatement, row: Row): def has_nerve_branches(cs: ConnectivityStatement, row: Row) -> bool: - return cs.tags.filter(tag=HAS_NERVE_BRANCHES_TAG).exists() + return any(tag.tag == HAS_NERVE_BRANCHES_TAG for tag in cs.prefetched_tags) def get_tag_filter(tag_name): def tag_filter(cs, row): - return cs.tags.filter(tag=tag_name).exists() + return any(tag.tag == tag_name for tag in cs.prefetched_tags) return tag_filter @@ -203,7 +216,7 @@ def generate_csv_attributes_mapping() -> Dict[str, Callable]: "Review notes": get_review_notes, "Proposed action": get_proposed_action, "Added to SCKAN (time stamp)": get_added_to_sckan_timestamp, - 'URI': get_statement_uri, + "URI": get_statement_uri, } exportable_tags = Tag.objects.filter(exportable=True) for tag in exportable_tags: @@ -220,13 +233,14 @@ def get_origin_row(origin: AnatomicalEntity, review_notes: str, curation_notes: ExportRelationships.hasSomaLocatedIn.value, curation_notes, review_notes, - layer='1' + layer="1", ) def get_destination_row(destination: Destination, total_vias: int): - if destination.from_entities.exists(): - connected_from_entities = destination.from_entities.all() + from_entities = list(destination.from_entities.all()) + if from_entities: + connected_from_entities = from_entities else: connected_from_entities = get_complete_from_entities_for_destination(destination) @@ -243,15 +257,16 @@ def get_destination_row(destination: Destination, total_vias: int): "", layer=layer_value, connected_from_names=connected_from_names, - connected_from_uris=connected_from_uris + connected_from_uris=connected_from_uris, ) for ae in destination.anatomical_entities.all() ] def get_via_row(via: Via): - if via.from_entities.exists(): - connected_from_entities = via.from_entities.all() + from_entities = list(via.from_entities.all()) + if from_entities: + connected_from_entities = from_entities else: connected_from_entities = get_complete_from_entities_for_via(via) @@ -268,7 +283,7 @@ def get_via_row(via: Via): "", layer=layer_value, connected_from_names=connected_from_names, - connected_from_uris=connected_from_uris + connected_from_uris=connected_from_uris, ) for ae in via.anatomical_entities.all() ] @@ -276,8 +291,8 @@ def get_via_row(via: Via): def _get_connected_from_info(entities): connected_from_info = [(entity.name, entity.ontology_uri) for entity in entities] if entities else [] - connected_from_names = '; '.join(name for name, _ in connected_from_info) - connected_from_uris = '; '.join(uri for _, uri in connected_from_info) + connected_from_names = "; ".join(name for name, _ in connected_from_info) + connected_from_uris = "; ".join(uri for _, uri in connected_from_info) return connected_from_names, connected_from_uris @@ -314,7 +329,7 @@ def get_circuit_role_row(cs: ConnectivityStatement): ) -def get_laterality_row(cs: ConnectivityStatement): +def get_projection_row(cs: ConnectivityStatement): return Row( cs.get_projection_display(), TEMP_PROJECTION_MAP.get(cs.projection, ""), @@ -350,7 +365,7 @@ def get_phenotype_row(cs: ConnectivityStatement): def get_projection_phenotype_row(cs: ConnectivityStatement): - projection_phenotype = cs.projection_phenotype if cs.projection_phenotype else "" + projection_phenotype = cs.projection_phenotype.name if cs.projection_phenotype else "" projection_phenotype_ontology_uri = cs.projection_phenotype.ontology_uri if cs.projection_phenotype else "" return Row( @@ -365,7 +380,7 @@ def get_projection_phenotype_row(cs: ConnectivityStatement): def get_functional_circuit_row(cs: ConnectivityStatement): return Row( - cs.functional_circuit_role, + cs.functional_circuit_role.name, cs.functional_circuit_role.ontology_uri, ExportRelationships.hasFunctionalCircuitRolePhenotype.label, ExportRelationships.hasFunctionalCircuitRolePhenotype.value, @@ -381,93 +396,80 @@ def get_forward_connection_row(forward_conn: ConnectivityStatement): ExportRelationships.hasForwardConnection.label, ExportRelationships.hasForwardConnection.value, "", - "" + "", ) -def get_rows(cs: ConnectivityStatement) -> List: +def get_rows(cs: ConnectivityStatement) -> List[Row]: rows = [] - review_notes = "\n".join( - [note.note for note in cs.notes.filter(type=NoteType.PLAIN)] + # Use prefetched notes + plain_notes = [ + note.note for note in cs.prefetched_notes if note.type == NoteType.PLAIN + ] + review_notes = "\n".join(plain_notes) + curation_notes = "\n".join( + note.note for note in cs.sentence.prefetched_sentence_notes ) - curation_notes = "\n".join([note.note for note in cs.sentence.notes.all()]) - for origin in cs.origins.all(): - try: - origin_row = get_origin_row(origin, review_notes, curation_notes) - rows.append(origin_row) - except Exception: - raise UnexportableConnectivityStatement("Error getting origin row") - - for via in cs.via_set.all().order_by("order"): - try: - via_rows = get_via_row(via) - rows.extend(via_rows) - except Exception: - raise UnexportableConnectivityStatement("Error getting via row") - - total_vias = cs.via_set.count() - for destination in cs.destinations.all(): - try: - destination_rows = get_destination_row(destination, total_vias) - rows.extend(destination_rows) - except Exception: - raise UnexportableConnectivityStatement("Error getting destination row") + # Origins + origins = cs.origins.all() + for origin in origins: + origin_row = get_origin_row(origin, review_notes, curation_notes) + rows.append(origin_row) + + # Vias (ordered by 'order' attribute) + vias = cs.via_set.all().order_by("order") + total_vias = vias.count() + for via in vias: + via_rows = get_via_row(via) + rows.extend(via_rows) + + # Destinations + destinations = cs.destinations.all() + for destination in destinations: + destination_rows = get_destination_row(destination, total_vias) + rows.extend(destination_rows) + + # Species for specie in cs.species.all(): - try: - rows.append(get_specie_row(specie)) - except Exception: - raise UnexportableConnectivityStatement("Error getting specie row") + rows.append(get_specie_row(specie)) + # Sex if cs.sex is not None: - try: - rows.append(get_sex_row(cs)) - except Exception: - raise UnexportableConnectivityStatement("Error getting sex row") + rows.append(get_sex_row(cs)) - try: + # Circuit Role + if cs.circuit_type is not None: rows.append(get_circuit_role_row(cs)) - except Exception: - raise UnexportableConnectivityStatement("Error getting circuit type row") - try: - rows.append(get_laterality_row(cs)) - except Exception: - raise UnexportableConnectivityStatement("Error getting laterality row") + # Projection + if cs.projection is not None: + rows.append(get_projection_row(cs)) - try: + # Soma Phenotype + if cs.laterality is not None: rows.append(get_soma_phenotype_row(cs)) - except Exception: - raise UnexportableConnectivityStatement("Error getting soma phenotype row") - try: + # Phenotype + if cs.phenotype is not None: rows.append(get_phenotype_row(cs)) - except Exception: - raise UnexportableConnectivityStatement("Error getting phenotype row") + # Projection Phenotype if cs.projection_phenotype: - try: - rows.append(get_projection_phenotype_row(cs)) - except Exception: - raise UnexportableConnectivityStatement("Error getting projection phenotype row") + rows.append(get_projection_phenotype_row(cs)) + # Functional Circuit Role if cs.functional_circuit_role: - try: - rows.append(get_functional_circuit_row(cs)) - except Exception: - raise UnexportableConnectivityStatement("Error getting functinal circuit role row") + rows.append(get_functional_circuit_row(cs)) + # Forward Connections for forward_conn in cs.forward_connection.all(): - try: - rows.append(get_forward_connection_row(forward_conn)) - except Exception: - raise UnexportableConnectivityStatement("Error getting forward connection row") + rows.append(get_forward_connection_row(forward_conn)) return rows def create_export_batch(qs: QuerySet, user: User) -> ExportBatch: - # do transition to EXPORTED state export_batch = ExportBatch.objects.create(user=user) export_batch.connectivity_statements.set(qs) export_batch.save() @@ -475,14 +477,15 @@ def create_export_batch(qs: QuerySet, user: User) -> ExportBatch: def compute_metrics(export_batch: ExportBatch): - # will be executed by post_save signal on ExportBatch - last_export_batch = ExportBatch.objects.exclude(id=export_batch.id).order_by("-created_at").first() + last_export_batch = ( + ExportBatch.objects.exclude(id=export_batch.id).order_by("-created_at").first() + ) if last_export_batch: last_export_batch_created_at = last_export_batch.created_at else: last_export_batch_created_at = None - # compute the metrics for this export + # Compute the metrics for this export if last_export_batch_created_at: sentences_created_qs = Sentence.objects.filter( created_date__gt=last_export_batch_created_at, @@ -497,18 +500,20 @@ def compute_metrics(export_batch: ExportBatch): ) else: connectivity_statements_created_qs = ConnectivityStatement.objects.all() - connectivity_statements_created_qs.exclude(state=CSState.DRAFT) # skip draft statements + connectivity_statements_created_qs = connectivity_statements_created_qs.exclude( + state=CSState.DRAFT + ) # skip draft statements export_batch.connectivity_statements_created = connectivity_statements_created_qs.count() - # export_batch.save() - - # compute the state metrics for this export - connectivity_statement_metrics = list(ConnectivityStatement.objects.values("state").annotate(count=Count("state"))) + # Compute the state metrics for this export + connectivity_statement_metrics = list( + ConnectivityStatement.objects.values("state").annotate(count=Count("state")) + ) for state in CSState: - try: - metric = [x for x in connectivity_statement_metrics if x.get("state") == state][0] - except IndexError: - metric = {"state": state.value, "count": 0} + metric = next( + (x for x in connectivity_statement_metrics if x.get("state") == state), + {"state": state.value, "count": 0}, + ) ExportMetrics.objects.create( export_batch=export_batch, entity=MetricEntity.CONNECTIVITY_STATEMENT, @@ -517,23 +522,23 @@ def compute_metrics(export_batch: ExportBatch): ) sentence_metrics = list(Sentence.objects.values("state").annotate(count=Count("state"))) for state in SentenceState: - try: - metric = [x for x in sentence_metrics if x.get("state") == state][0] - except IndexError: - metric = {"state": state.value, "count": 0} + metric = next( + (x for x in sentence_metrics if x.get("state") == state), + {"state": state.value, "count": 0}, + ) ExportMetrics.objects.create( export_batch=export_batch, entity=MetricEntity.SENTENCE, state=SentenceState(metric["state"]), count=metric["count"], ) - # ExportMetrics return export_batch def do_transition_to_exported(export_batch: ExportBatch, user: User): system_user = User.objects.get(username="system") - for connectivity_statement in export_batch.connectivity_statements.all(): + connectivity_statements = export_batch.connectivity_statements.all() + for connectivity_statement in connectivity_statements: available_transitions = [ available_state.target for available_state in connectivity_statement.get_available_user_state_transitions( @@ -541,7 +546,6 @@ def do_transition_to_exported(export_batch: ExportBatch, user: User): ) ] if CSState.EXPORTED in available_transitions: - # we need to update the state to exported when we are in the NP0 approved state and the system user has the permission to do so cs = ConnectivityStatementStateService(connectivity_statement).do_transition( CSState.EXPORTED, system_user, user ) @@ -549,7 +553,6 @@ def do_transition_to_exported(export_batch: ExportBatch, user: User): def dump_export_batch(export_batch, folder_path: typing.Optional[str] = None) -> str: - # returns the path of the exported file if folder_path is None: folder_path = tempfile.gettempdir() @@ -560,36 +563,65 @@ def dump_export_batch(export_batch, folder_path: typing.Optional[str] = None) -> csv_attributes_mapping = generate_csv_attributes_mapping() + # Prefetch related data with filters + notes_prefetch = Prefetch( + "notes", + queryset=Note.objects.filter(type__in=[NoteType.PLAIN, NoteType.DIFFERENT]), + to_attr="prefetched_notes", + ) + sentence_notes_prefetch = Prefetch( + "sentence__notes", + queryset=Note.objects.all(), + to_attr="prefetched_sentence_notes", + ) + tags_prefetch = Prefetch( + "tags", queryset=Tag.objects.all(), to_attr="prefetched_tags" + ) + + connectivity_statements = export_batch.connectivity_statements.select_related( + "sentence", "sex", "functional_circuit_role", "projection_phenotype" + ).prefetch_related( + "origins", + notes_prefetch, + tags_prefetch, + "species", + "forward_connection", + "provenance_set", + sentence_notes_prefetch, + "via_set__anatomical_entities", + "via_set__from_entities", + "destinations__anatomical_entities", + "destinations__from_entities", + ) + with open(filepath, "w", newline="") as csvfile: writer = csv.writer(csvfile) - # Write header row headers = csv_attributes_mapping.keys() writer.writerow(headers) - # Write data rows - for obj in export_batch.connectivity_statements.all(): + for cs in connectivity_statements: try: - rows = get_rows(obj) + rows = get_rows(cs) except UnexportableConnectivityStatement as e: logging.warning( - f"Connectivity Statement with id {obj.id} skipped due to {e}" + f"Connectivity Statement with id {cs.id} skipped due to {e}" ) continue + for row in rows: - row_content = [] - for key in csv_attributes_mapping: - row_content.append(csv_attributes_mapping[key](obj, row)) + row_content = [func(cs, row) for func in csv_attributes_mapping.values()] writer.writerow(row_content) + return filepath def export_connectivity_statements( - qs: QuerySet, user: User, folder_path: typing.Optional[str] + qs: QuerySet, user: User, folder_path: typing.Optional[str] ) -> typing.Tuple[str, ExportBatch]: with transaction.atomic(): - # make sure create_export_batch and do_transition_to_exported are in one database transaction + # Ensure create_export_batch and do_transition_to_exported are in one database transaction export_batch = create_export_batch(qs, user) do_transition_to_exported(export_batch, user) export_file = dump_export_batch(export_batch, folder_path) - return export_file, export_batch + return export_file, export_batch \ No newline at end of file diff --git a/backend/composer/signals.py b/backend/composer/signals.py index f0a40a8e..5b559cba 100644 --- a/backend/composer/signals.py +++ b/backend/composer/signals.py @@ -5,8 +5,8 @@ from django_fsm.signals import post_transition from .enums import CSState, NoteType -from .models import ConnectivityStatement, ExportBatch, Note, Sentence, Synonym, \ - AnatomicalEntity, Layer, Region, AnatomicalEntityMeta +from .models import ConnectivityStatement, ExportBatch, Note, Sentence, \ + AnatomicalEntity, Layer, Region from .services.export_services import compute_metrics, ConnectivityStatementStateService @@ -69,4 +69,4 @@ def delete_associated_entities(sender, instance, **kwargs): # Delete the associated region_layer if it exists if instance.region_layer: - instance.region_layer.delete() \ No newline at end of file + instance.region_layer.delete() diff --git a/backend/composer/templates/admin/index.html b/backend/composer/templates/admin/index.html index 6b3952b3..73399f4f 100644 --- a/backend/composer/templates/admin/index.html +++ b/backend/composer/templates/admin/index.html @@ -30,12 +30,16 @@
Export statistics
- Create new export - + + Create new export + +