Skip to content

Commit

Permalink
Feature/sckan 252 (#211)
Browse files Browse the repository at this point in the history
* SCKAN-252 fix: Rollback graph algorithm changes

* SCKAN-252 chore: Update tests

* SCKAN-252 chore: Add new test to journey calculation algorithm

* SCKAN-252 fix: Update generate_paths algorithm
  • Loading branch information
afonsobspinto authored Feb 6, 2024
1 parent f9eac40 commit 3f615cc
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 37 deletions.
23 changes: 13 additions & 10 deletions backend/composer/services/graph_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,33 +4,33 @@
def generate_paths(origins, vias, destinations):
paths = []
# Calculate the total number of layers, including origins and destinations
number_of_layers = len(vias) + 2 if vias else 2
destination_layer = max([via.order for via in vias] + [0]) + 2 if vias else 1

# Handle direct connections from origins to destinations
for origin in origins:
for destination in destinations:
# Directly use pre-fetched 'from_entities' without triggering additional queries
if origin in destination.from_entities.all() or (not destination.from_entities.exists() and len(vias) == 0):
for dest_entity in destination.anatomical_entities.all():
paths.append([(origin.name, 0), (dest_entity.name, number_of_layers - 1)])
paths.append([(origin.name, 0), (dest_entity.name, destination_layer)])

# Handle connections involving vias
if vias:
for origin in origins:
# Generate paths through vias for each origin
paths.extend(create_paths_from_origin(origin, vias, destinations, [(origin.name, 0)], number_of_layers))
paths.extend(create_paths_from_origin(origin, vias, destinations, [(origin.name, 0)], destination_layer))

# Remove duplicates from the generated paths
unique_paths = [list(path) for path in set(tuple(path) for path in paths)]

return unique_paths


def create_paths_from_origin(origin, vias, destinations, current_path, number_of_layers):
def create_paths_from_origin(origin, vias, destinations, current_path, destination_layer):
# Base case: if there are no more vias to process
if not vias:
# Generate direct connections from the current path to destinations
return [current_path + [(dest_entity.name, number_of_layers - 1)] for dest in destinations
return [current_path + [(dest_entity.name, destination_layer)] for dest in destinations
for dest_entity in dest.anatomical_entities.all()
if current_path[-1][0] in list(
a.name for a in dest.from_entities.all()) or not dest.from_entities.exists()]
Expand All @@ -42,21 +42,21 @@ def create_paths_from_origin(origin, vias, destinations, current_path, number_of
# This checks if the last node in the current path is one of the nodes that can lead to the current via.
# In other words, it checks if there is a valid connection
# from the last node in the current path to the current via.
if len(current_path) == via_layer and (current_path[-1][0] in list(
a.name for a in current_via.from_entities.all()) or not current_via.from_entities.exists()):
if current_path[-1][0] in list(
a.name for a in current_via.from_entities.all()) or not current_via.from_entities.exists():
for entity in current_via.anatomical_entities.all():
# Build new sub-paths including the current via entity
new_sub_path = current_path + [(entity.name, via_layer)]
# Recursively call to build paths from the next vias
new_paths.extend(
create_paths_from_origin(origin, vias[idx + 1:], destinations, new_sub_path, number_of_layers))
create_paths_from_origin(origin, vias[idx + 1:], destinations, new_sub_path, destination_layer))

# Check for direct connections to destinations from the current via
for dest in destinations:
for dest_entity in dest.anatomical_entities.all():
if entity.name in list(a.name for a in dest.from_entities.all()):
# Add path to destinations directly from the current via
new_paths.append(new_sub_path + [(dest_entity.name, number_of_layers - 1)])
new_paths.append(new_sub_path + [(dest_entity.name, destination_layer)])

return new_paths

Expand Down Expand Up @@ -85,7 +85,10 @@ def consolidate_paths(paths):

paths = consolidated + [paths[i] for i in range(len(paths)) if i not in used_indices]

return [[((node[0].replace(JOURNEY_DELIMITER, ' or '), node[1]) if (node[1] == 0 or path.index(node) == len(path)-1) else (node[0].replace(JOURNEY_DELIMITER, ', '), node[1])) for node in path] for path in paths]
return [[((node[0].replace(JOURNEY_DELIMITER, ' or '), node[1]) if (
node[1] == 0 or path.index(node) == len(path) - 1) else (
node[0].replace(JOURNEY_DELIMITER, ', '), node[1])) for node in path] for path in paths]


def can_merge(path1, path2):
# Ensure paths are of the same length
Expand Down
109 changes: 82 additions & 27 deletions backend/tests/test_journey.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def test_journey_simple_graph_with_jump(self):

cs.origins.add(origin1, origin2)

via = Via.objects.create(connectivity_statement=cs, order=0)
via = Via.objects.create(connectivity_statement=cs)
via.anatomical_entities.add(via1)
via.from_entities.add(origin1)

Expand Down Expand Up @@ -102,7 +102,7 @@ def test_journey_simple_direct_graph(self):
self.assertTrue(all_paths == expected_paths)

journey_paths = consolidate_paths(all_paths)
expected_journey = [[('Oa, Ob', 0), ('Da', 1)]]
expected_journey = [[('Oa or Ob', 0), ('Da', 1)]]
self.assertTrue(journey_paths == expected_journey)

def test_journey_simple_graph_no_jumps(self):
Expand All @@ -124,7 +124,7 @@ def test_journey_simple_graph_no_jumps(self):
cs.origins.add(origin1, origin2)

# Create Via
via = Via.objects.create(connectivity_statement=cs, order=0)
via = Via.objects.create(connectivity_statement=cs)
via.anatomical_entities.add(via1)
via.from_entities.add(origin1, origin2)

Expand Down Expand Up @@ -154,7 +154,7 @@ def test_journey_simple_graph_no_jumps(self):
expected_paths.sort()
self.assertTrue(all_paths == expected_paths)

expected_journey = [[('Oa, Ob', 0), ('V1a', 1), ('Da', 2)]]
expected_journey = [[('Oa or Ob', 0), ('V1a', 1), ('Da', 2)]]
journey_paths = consolidate_paths(all_paths)
self.assertTrue(journey_paths == expected_journey)

Expand All @@ -180,7 +180,7 @@ def test_journey_multiple_vias_no_jumps(self):
cs.origins.add(origin1, origin2)

# Create Via
via = Via.objects.create(connectivity_statement=cs, order=0)
via = Via.objects.create(connectivity_statement=cs)
via.anatomical_entities.add(via1, via2)
via.from_entities.add(origin1, origin2)

Expand Down Expand Up @@ -213,7 +213,7 @@ def test_journey_multiple_vias_no_jumps(self):
self.assertTrue(all_paths == expected_paths)

expected_journey = [
[('Oa, Ob', 0), ('V1a, V1b', 1), ('Da', 2)]
[('Oa or Ob', 0), ('V1a, V1b', 1), ('Da', 2)]
]
journey_paths = consolidate_paths(all_paths)
self.assertTrue(journey_paths == expected_journey)
Expand Down Expand Up @@ -246,19 +246,19 @@ def test_journey_complex_graph(self):
cs.origins.add(origin_a, origin_b)

# Create Vias
via1 = Via.objects.create(connectivity_statement=cs, order=0)
via1 = Via.objects.create(connectivity_statement=cs)
via1.anatomical_entities.add(via1_a)
via1.from_entities.add(origin_a, origin_b)

via2 = Via.objects.create(connectivity_statement=cs, order=1)
via2 = Via.objects.create(connectivity_statement=cs)
via2.anatomical_entities.add(via2_a, via2_b)
via2.from_entities.add(via1_a)

via3 = Via.objects.create(connectivity_statement=cs, order=2)
via3 = Via.objects.create(connectivity_statement=cs)
via3.anatomical_entities.add(via3_a)
via3.from_entities.add(via2_a, origin_a)

via4 = Via.objects.create(connectivity_statement=cs, order=3)
via4 = Via.objects.create(connectivity_statement=cs)
via4.anatomical_entities.add(via4_a)
via4.from_entities.add(via2_a)

Expand Down Expand Up @@ -293,9 +293,9 @@ def test_journey_complex_graph(self):
expected_paths.sort()
self.assertTrue(all_paths == expected_paths)

expected_journey = [[('Oa, Ob', 0), ('V1a', 1), ('V2a', 2), ('V3a', 3), ('Da', 5)],
[('Oa, Ob', 0), ('V1a', 1), ('V2a', 2), ('V4a', 4), ('Da', 5)],
[('Oa, Ob', 0), ('V1a', 1), ('V2b', 2), ('Da', 5)],
expected_journey = [[('Oa or Ob', 0), ('V1a', 1), ('V2a', 2), ('V3a', 3), ('Da', 5)],
[('Oa or Ob', 0), ('V1a', 1), ('V2a', 2), ('V4a', 4), ('Da', 5)],
[('Oa or Ob', 0), ('V1a', 1), ('V2b', 2), ('Da', 5)],
[('Oa', 0), ('V3a', 3), ('Da', 5)]]
journey_paths = consolidate_paths(all_paths)
self.assertTrue(journey_paths == expected_journey)
Expand All @@ -321,27 +321,27 @@ def test_journey_complex_graph_2(self):
cs.origins.add(origin_a, origin_b)

# Create Vias
via1 = Via.objects.create(connectivity_statement=cs, order=0)
via1 = Via.objects.create(connectivity_statement=cs)
via1.anatomical_entities.add(via1_a)
via1.from_entities.add(origin_a, origin_b)

via2 = Via.objects.create(connectivity_statement=cs, order=1)
via2 = Via.objects.create(connectivity_statement=cs)
via2.anatomical_entities.add(via2_a, via2_b)
via2.from_entities.add(via1_a)

via3 = Via.objects.create(connectivity_statement=cs, order=2)
via3 = Via.objects.create(connectivity_statement=cs)
via3.anatomical_entities.add(via3_a)
via3.from_entities.add(via2_a, via1_a)

via4 = Via.objects.create(connectivity_statement=cs, order=3)
via4 = Via.objects.create(connectivity_statement=cs)
via4.anatomical_entities.add(via4_a)
via4.from_entities.add(via2_b, via3_a)

via5 = Via.objects.create(connectivity_statement=cs, order=4)
via5 = Via.objects.create(connectivity_statement=cs)
via5.anatomical_entities.add(via5_a, via5_b)
via5.from_entities.add(via4_a)

via6 = Via.objects.create(connectivity_statement=cs, order=5)
via6 = Via.objects.create(connectivity_statement=cs)
via6.anatomical_entities.add(via6_a)
via6.from_entities.add(via5_a)

Expand Down Expand Up @@ -382,12 +382,12 @@ def test_journey_complex_graph_2(self):
self.assertTrue(all_paths == expected_paths)

expected_journey = [
[('Oa, Ob', 0), ('V1a', 1), ('V2a', 2), ('V3a', 3), ('V4a', 4), ('V5a', 5), ('V6a', 6), ('Da', 7)],
[('Oa, Ob', 0), ('V1a', 1), ('V2b', 2), ('V4a', 4), ('V5a', 5), ('V6a', 6), ('Da', 7)],
[('Oa, Ob', 0), ('V1a', 1), ('V2a', 2), ('V3a', 3), ('V4a', 4), ('V5b', 5), ('Da', 7)],
[('Oa, Ob', 0), ('V1a', 1), ('V3a', 3), ('V4a', 4), ('V5a', 5), ('V6a', 6), ('Da', 7)],
[('Oa, Ob', 0), ('V1a', 1), ('V2b', 2), ('V4a', 4), ('V5b', 5), ('Da', 7)],
[('Oa, Ob', 0), ('V1a', 1), ('V3a', 3), ('V4a', 4), ('V5b', 5), ('Da', 7)]
[('Oa or Ob', 0), ('V1a', 1), ('V2a', 2), ('V3a', 3), ('V4a', 4), ('V5a', 5), ('V6a', 6), ('Da', 7)],
[('Oa or Ob', 0), ('V1a', 1), ('V2b', 2), ('V4a', 4), ('V5a', 5), ('V6a', 6), ('Da', 7)],
[('Oa or Ob', 0), ('V1a', 1), ('V2a', 2), ('V3a', 3), ('V4a', 4), ('V5b', 5), ('Da', 7)],
[('Oa or Ob', 0), ('V1a', 1), ('V3a', 3), ('V4a', 4), ('V5a', 5), ('V6a', 6), ('Da', 7)],
[('Oa or Ob', 0), ('V1a', 1), ('V2b', 2), ('V4a', 4), ('V5b', 5), ('Da', 7)],
[('Oa or Ob', 0), ('V1a', 1), ('V3a', 3), ('V4a', 4), ('V5b', 5), ('Da', 7)]
]

journey_paths = consolidate_paths(all_paths)
Expand All @@ -413,7 +413,7 @@ def test_journey_cycles(self):
cs.origins.add(origin1, origin2)

# Create Via
via = Via.objects.create(connectivity_statement=cs, order=0)
via = Via.objects.create(connectivity_statement=cs)
via.anatomical_entities.add(origin1)
via.from_entities.add(origin1)

Expand Down Expand Up @@ -446,10 +446,65 @@ def test_journey_cycles(self):

expected_journey = [
[('Oa', 0), ('Oa', 1), ('Da', 2)],
[('Oa, Ob', 0), ('Da', 2)]
[('Oa or Ob', 0), ('Da', 2)]
]

journey_paths = consolidate_paths(all_paths)
expected_journey.sort()
journey_paths.sort()
self.assertTrue(journey_paths == expected_journey)

def test_journey_nonconsecutive_vias(self):
# Test setup
sentence = Sentence.objects.create()
cs = ConnectivityStatement.objects.create(sentence=sentence)

origin1 = AnatomicalEntity.objects.create(name='Oa')
via1 = AnatomicalEntity.objects.create(name='V1a')
via2 = AnatomicalEntity.objects.create(name='V2a')
destination1 = AnatomicalEntity.objects.create(name='Da')

cs.origins.add(origin1)

via_a = Via.objects.create(connectivity_statement=cs)
via_a.anatomical_entities.add(via1)
via_a.from_entities.add(origin1)

via_b = Via.objects.create(connectivity_statement=cs)
via_b.anatomical_entities.add(via2)
via_b.from_entities.add(via1)

# Directly change the order of vias in the database
Via.objects.filter(pk=via_a.pk).update(order=2) # Change to non-zero start
Via.objects.filter(pk=via_b.pk).update(order=5) # Change to non-consecutive

destination = Destination.objects.create(connectivity_statement=cs)
destination.anatomical_entities.add(destination1)
destination.from_entities.add(via2)

# Prefetch related data
origins = list(cs.origins.all())
vias = list(
Via.objects.filter(connectivity_statement=cs).prefetch_related('anatomical_entities', 'from_entities'))
destinations = list(
Destination.objects.filter(connectivity_statement=cs).prefetch_related('anatomical_entities',
'from_entities'))

expected_paths = [
[('Oa', 0), ('V1a', 3), ('V2a', 6), ('Da', 7)],
]

all_paths = generate_paths(origins, vias, destinations)

all_paths.sort()
expected_paths.sort()
self.assertTrue(all_paths == expected_paths, f"Expected paths {expected_paths}, but found {all_paths}")

journey_paths = consolidate_paths(all_paths)
expected_journey = [
[('Oa', 0), ('V1a', 3), ('V2a', 6), ('Da', 7)],
]
journey_paths.sort()
expected_journey.sort()
self.assertTrue(journey_paths == expected_journey,
f"Expected journey {expected_journey}, but found {journey_paths}")

0 comments on commit 3f615cc

Please sign in to comment.