Skip to content

Commit

Permalink
CELE-32 feat: Connect join with new enhanced model
Browse files Browse the repository at this point in the history
  • Loading branch information
afonsobspinto committed Aug 1, 2024
1 parent fd626a3 commit 914cea1
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import type React from "react";
import {useMemo} from "react";
import {Menu, MenuItem} from "@mui/material";
import {type NeuronGroup, ViewerType} from "../../../models";
import {calculateSplitPositions, isNeuronClass} from "../../../helpers/twoD/twoDHelpers.ts";
import {calculateMeanPosition, calculateSplitPositions, isNeuronClass} from "../../../helpers/twoD/twoDHelpers.ts";
import {useSelectedWorkspace} from "../../../hooks/useSelectedWorkspace.ts";
import {Position} from "cytoscape";

Expand Down Expand Up @@ -108,7 +108,7 @@ const ContextMenu: React.FC<ContextMenuProps> = ({
}).map(neuron => neuron.name);

// Calculate the positions for the individual neurons
const basePosition = workspace.availableNeurons[neuronId].viewerData[ViewerType.Graph]?.position || {
const basePosition = workspace.availableNeurons[neuronId].viewerData[ViewerType.Graph]?.defaultPosition || {
x: 0,
y: 0
};
Expand All @@ -117,7 +117,12 @@ const ContextMenu: React.FC<ContextMenuProps> = ({
// Update the selected neurons with individual neurons
individualNeurons.forEach(neuronName => {
newSelectedNeurons.add(neuronName);
positionUpdates[neuronName] = {position: positions[neuronName], visibility: true};
// Only set the position if it doesn't exist yet
if (!workspace.availableNeurons[neuronName].viewerData[ViewerType.Graph]?.defaultPosition) {
positionUpdates[neuronName] = {position: positions[neuronName], visibility: true};
} else {
positionUpdates[neuronName] = {visibility: true};
}
});

// Remove the corresponding class from the toJoin set
Expand All @@ -131,20 +136,8 @@ const ContextMenu: React.FC<ContextMenuProps> = ({
}
});

workspace.customUpdate(draft => {
// Update the selected neurons
draft.selectedNeurons = newSelectedNeurons;

// Update the positions and visibility for the individual neurons and class neuron
Object.entries(positionUpdates).forEach(([neuronName, update]) => {
if (draft.availableNeurons[neuronName]) {
if (update.position !== undefined) {
draft.availableNeurons[neuronName].viewerData[ViewerType.Graph].position = update.position;
}
draft.availableNeurons[neuronName].viewerData[ViewerType.Graph].visibility = update.visibility;
}
});
});
// Update the selected neurons in the workspace
updateWorkspace(newSelectedNeurons, positionUpdates)

return {split: newSplit, join: newJoin};
});
Expand All @@ -157,16 +150,30 @@ const ContextMenu: React.FC<ContextMenuProps> = ({
const newSplit = new Set(prevState.split);

const newSelectedNeurons = new Set(workspace.selectedNeurons);
const positionUpdates: Record<string, { position?: Position | null; visibility: boolean }> = {};


workspace.selectedNeurons.forEach(neuronId => {
const neuronClass = workspace.availableNeurons[neuronId].nclass;

const individualNeurons = Object.values(workspace.availableNeurons).filter(neuron => neuron.nclass === neuronClass && neuron.name !== neuronClass);
const individualNeuronIds = individualNeurons.map(neuron => neuron.name);

// Calculate and set the class position if not set already
const classPosition = calculateMeanPosition(individualNeuronIds, workspace);

if (!workspace.availableNeurons[neuronClass].viewerData[ViewerType.Graph]?.defaultPosition) {
positionUpdates[neuronClass] = {position: classPosition, visibility: true};
} else {
positionUpdates[neuronClass] = {...positionUpdates[neuronClass], visibility: true};
}
// Remove the individual neurons from the selected neurons and add the class neuron
Object.values(workspace.availableNeurons).forEach(neuron => {
if (neuron.nclass === neuronClass && neuron.name !== neuronClass) {
newSelectedNeurons.delete(neuron.name);
newJoin.add(neuron.name);
}
individualNeuronIds.forEach(neuronName => {
newSelectedNeurons.delete(neuronName);
newJoin.add(neuronName);

// Set individual neurons' visibility to false
positionUpdates[neuronName] = {visibility: false};
});
newSelectedNeurons.add(neuronClass);

Expand All @@ -179,15 +186,30 @@ const ContextMenu: React.FC<ContextMenuProps> = ({
});

// Update the selected neurons in the workspace
workspace.customUpdate(draft => {
draft.selectedNeurons = newSelectedNeurons;
});
updateWorkspace(newSelectedNeurons, positionUpdates)

return {split: newSplit, join: newJoin};
});
onClose();
};

const updateWorkspace = (newSelectedNeurons, positionUpdates) => {
workspace.customUpdate(draft => {
// Update the selected neurons
draft.selectedNeurons = newSelectedNeurons;

// Update the positions and visibility for the individual neurons and class neuron
Object.entries(positionUpdates).forEach(([neuronName, update]) => {
if (draft.availableNeurons[neuronName]) {
if (update.position !== undefined) {
draft.availableNeurons[neuronName].viewerData[ViewerType.Graph].defaultPosition = update.position;
}
draft.availableNeurons[neuronName].viewerData[ViewerType.Graph].visibility = update.visibility;
}
});
});
};

const handleAddToWorkspace = () => {
workspace.customUpdate((draft) => {
workspace.selectedNeurons.forEach((neuronId) => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ export const computeGraphDifferences = (
} else {
const neuron = workspace.availableNeurons[nodeId];
const attributes = extractNeuronAttributes(neuron);
const position = neuron.viewerData[ViewerType.Graph]?.position ?? null;
const position = neuron.viewerData[ViewerType.Graph]?.defaultPosition ?? null;
nodesToAdd.push(createNode(nodeId, workspace.selectedNeurons.has(nodeId), attributes, position));
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ export const calculateMeanPosition = (nodeIds: string[], workspace: Workspace):

nodeIds.forEach(nodeId => {
const neuron = workspace.availableNeurons[nodeId];
const position = neuron?.viewerData[ViewerType.Graph]?.position;
const position = neuron?.viewerData[ViewerType.Graph]?.defaultPosition;
if (position) {
totalX += position.x;
totalY += position.y;
Expand Down Expand Up @@ -193,7 +193,7 @@ export const updateWorkspaceNeurons2DViewerData = (workspace: Workspace, cy: Cor
cy.nodes().forEach(node => {
const neuronId = node.id();
if (draft.availableNeurons[neuronId]) {
draft.availableNeurons[neuronId].viewerData[ViewerType.Graph].position = { ...node.position() };
draft.availableNeurons[neuronId].viewerData[ViewerType.Graph].defaultPosition = { ...node.position() };
draft.availableNeurons[neuronId].viewerData[ViewerType.Graph].visibility = true;
}
});
Expand Down
2 changes: 1 addition & 1 deletion applications/visualizer/frontend/src/models/models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ export interface EnhancedNeuron extends Neuron {
}

export interface GraphViewerData {
position: Position | null;
defaultPosition: Position | null;
visibility: boolean;
}

Expand Down
2 changes: 1 addition & 1 deletion applications/visualizer/frontend/src/models/workspace.ts
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ export class Workspace {
...neuron,
viewerData: {
[ViewerType.Graph]: {
position: null,
defaultPosition: null,
visibility: false,
},
[ViewerType.ThreeD]: {},
Expand Down

0 comments on commit 914cea1

Please sign in to comment.