refactor: anisotropy not hardcoded (#285)
Co-authored-by: anna-grim <[email protected]>
anna-grim and anna-grim authored Dec 23, 2024
1 parent bab7c0b commit 93a0b49
Showing 7 changed files with 87 additions and 70 deletions.
33 changes: 22 additions & 11 deletions src/deep_neurographs/fragments_graph.py
@@ -4,9 +4,9 @@
@author: Anna Grim
@email: [email protected]
Implementation of subclass of Networkx.Graph called "FragmentsGraph".
NOTE: SAVE LABEL UPDATES --- THERE IS A BUG IN FEATURE GENERATION
Implementation of subclass of Networkx.Graph called "FragmentsGraph" which is
a graph that is initialized by loading swc files (i.e. fragments) from a
predicted segmentation.
"""
import zipfile
@@ -30,17 +30,23 @@ class FragmentsGraph(nx.Graph):
"""

def __init__(self, img_bbox=None, node_spacing=1):
def __init__(
self, anisotropy=[1.0, 1.0, 1.0], img_bbox=None, node_spacing=1
):
"""
Initializes an instance of NeuroGraph.
Parameters
----------
anisotropy : ArrayLike, optional
Image to physical coordinates scaling factors to account for the
anisotropy of the microscope. The default is [1.0, 1.0, 1.0].
img_bbox : dict or None, optional
Dictionary with the keys "min" and "max" which specify a bounding
box in an image. The default is None.
node_spacing : int, optional
Spacing (in microns) between nodes. The default is 1.
Physical spacing (in microns) between nodes in swcs. The default
is 1.
Returns
-------
@@ -49,6 +55,7 @@ def __init__(self, img_bbox=None, node_spacing=1):
"""
super(FragmentsGraph, self).__init__()
# General class attributes
self.anisotropy = anisotropy
self.leaf_kdtree = None
self.node_cnt = 0
self.node_spacing = node_spacing
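A hedged usage sketch of the refactored constructor above; the import path follows the file location in this diff, and the voxel sizes are placeholders rather than values from the repository.

from deep_neurographs.fragments_graph import FragmentsGraph

# Build a fragments graph whose coordinate conversions use explicit scaling
# factors instead of hardcoded ones (the point of this commit).
graph = FragmentsGraph(
    anisotropy=[0.5, 0.5, 1.0],  # microns per voxel along x, y, z (made up)
    img_bbox=None,               # no bounding box, so no clipping
    node_spacing=1,              # ~1 micron between sampled nodes
)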
@@ -908,8 +915,8 @@ def oriented_edge(self, edge, i, key="xyz"):

def is_contained(self, node_or_xyz, buffer=0):
if self.bbox:
coord = self.to_voxels(node_or_xyz)
return util.is_contained(self.bbox, coord, buffer=buffer)
voxel = self.to_voxels(node_or_xyz, self.anisotropy)
return util.is_contained(self.bbox, voxel, buffer=buffer)
else:
return True

@@ -921,13 +928,17 @@ def branch_contained(self, xyz_list):
else:
return True

def to_voxels(self, node_or_xyz, shift=False):
def to_voxels(self, node_or_xyz, shift=np.array([0, 0, 0])):
# Get xyz coordinate
shift = self.origin if shift else np.zeros((3))
if type(node_or_xyz) is int:
coord = img_util.to_voxels(self.nodes[node_or_xyz]["xyz"])
xyz = self.nodes[node_or_xyz]["xyz"]
else:
coord = img_util.to_voxels(node_or_xyz)
return coord - shift
xyz = node_or_xyz

# Coordinate conversion
voxel = img_util.to_voxels(xyz, self.anisotropy)
return voxel - shift

def is_leaf(self, i):
"""
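The to_voxels method above now threads self.anisotropy through to img_util.to_voxels. A minimal standalone sketch of what that conversion presumably does, assuming the anisotropy values are microns per voxel and that physical coordinates are divided by them (the exact signature in img_util may differ):

import numpy as np

def to_voxels(xyz, anisotropy=(1.0, 1.0, 1.0), multiscale=0):
    # Physical (micron) coordinates -> voxel indices: divide out the per-axis
    # scaling, then shrink by 2**multiscale for the chosen pyramid level.
    voxel = np.asarray(xyz) / np.asarray(anisotropy)
    return (voxel / 2 ** multiscale).astype(int)

print(to_voxels([100.0, 50.0, 25.0], anisotropy=(0.5, 0.5, 1.0)))  # [200 100  25]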
16 changes: 5 additions & 11 deletions src/deep_neurographs/inference.py
@@ -164,7 +164,7 @@ def run(self, fragments_pointer):
"""
# Initializations
self.report_experiment()
self.log_experiment()
self.write_metadata()
t0 = time()

@@ -181,22 +181,16 @@ def run(self, fragments_pointer):
t, unit = util.time_writer(time() - t0)
self.report(f"Total Runtime: {round(t, 4)} {unit}\n")

def run_schedule(
self, fragments_pointer, radius_schedule, save_all_rounds=False
):
def run_schedule(self, fragments_pointer, radius_schedule):
t0 = time()
self.report_experiment()
self.log_experiment()
self.build_graph(fragments_pointer)
for round_id, radius in enumerate(radius_schedule):
self.report(f"--- Round {round_id + 1}: Radius = {radius} ---")
round_id += 1
self.generate_proposals(radius)
self.run_inference()
if save_all_rounds:
self.save_results(round_id=round_id)

if not save_all_rounds:
self.save_results(round_id=round_id)
self.save_results(round_id=round_id)

t, unit = util.time_writer(time() - t0)
self.report(f"Total Runtime: {round(t, 4)} {unit}\n")
@@ -433,7 +427,7 @@ def report(self, txt):
self.log_handle.write(txt)
self.log_handle.write("\n")

def report_experiment(self):
def log_experiment(self):
self.report("\nExperiment Overview")
self.report("-------------------------------------------------------")
self.report(f"Sample_ID: {self.sample_id}")
62 changes: 32 additions & 30 deletions src/deep_neurographs/machine_learning/feature_generation.py
@@ -36,7 +36,8 @@ class FeatureGenerator:
def __init__(
self,
img_path,
downsample_factor,
multiscale,
anisotropy=[1.0, 1.0, 1.0],
label_path=None,
is_multimodal=False,
):
@@ -47,9 +48,11 @@ def __init__(
----------
img_path : str
Path to the raw image assumed to be stored in a GCS bucket.
downsample_factor : int
Downsampling factor that accounts for which level in the image
pyramid the voxel coordinates must index into.
multiscale : int
Level in the image pyramid that voxel coordinates must index into.
anisotropy : ArrayLike, optional
Image to physical coordinates scaling factors to account for the
anisotropy of the microscope. The default is [1.0, 1.0, 1.0].
label_path : str, optional
Path to the segmentation assumed to be stored on a GCS bucket. The
default is None.
@@ -62,11 +65,12 @@ def __init__(
None
"""
# Initialize instance attributes
self.downsample_factor = downsample_factor
# General instance attributes
self.anisotropy = anisotropy
self.multiscale = multiscale
self.is_multimodal = is_multimodal

# Initialize image-based attributes
# Open images
driver = "n5" if ".n5" in img_path else "zarr"
self.img = img_util.open_tensorstore(img_path, driver=driver)
if label_path:
@@ -75,24 +79,22 @@ def __init__(
self.labels = None

# Set chunk shapes
self.img_patch_shape = self.set_patch_shape(downsample_factor)
self.img_patch_shape = self.set_patch_shape(multiscale)
self.label_patch_shape = self.set_patch_shape(0)

# Validate embedding requirements
if self.is_multimodal and not label_path:
raise("Must provide labels to generate image embeddings")

@classmethod
def set_patch_shape(cls, downsample_factor):
def set_patch_shape(cls, multiscale):
"""
Adjusts the chunk shape by downsampling each dimension by a specified
factor.
Parameters
----------
downsample_factor : int
The factor by which to downsample each dimension of the current
chunk shape.
None
Returns
-------
@@ -101,7 +103,7 @@ def set_patch_shape(cls, downsample_factor):
factor.
"""
return [s // 2 ** downsample_factor for s in cls.patch_shape]
return [s // 2 ** multiscale for s in cls.patch_shape]

@classmethod
def get_n_profile_points(cls):
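A worked example of the set_patch_shape logic above; the base patch_shape of (64, 64, 64) is an assumption for illustration, not a value taken from this class.

patch_shape = (64, 64, 64)  # hypothetical full-resolution patch shape

def set_patch_shape(multiscale):
    # Each level of the image pyramid halves the resolution, so the patch
    # shrinks by a factor of 2**multiscale along every axis.
    return [s // 2 ** multiscale for s in patch_shape]

print(set_patch_shape(0))  # [64, 64, 64]
print(set_patch_shape(2))  # [16, 16, 16]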
@@ -114,7 +116,7 @@ def run(self, neurograph, proposals_dict, radius):
Parameters
----------
neurograph : NeuroGraph
neurograph : FragmentsGraph
Graph that "proposals" belong to.
proposals_dict : dict
Dictionary that contains the items (1) "proposals" which are the
@@ -154,7 +156,7 @@ def run_on_nodes(self, neurograph, computation_graph):
Parameters
----------
neurograph : NeuroGraph
neurograph : FragmentsGraph
NeuroGraph generated from a predicted segmentation.
computation_graph : networkx.Graph
Graph used by GNN to classify proposals.
@@ -173,7 +175,7 @@ def run_on_branches(self, neurograph, computation_graph):
Parameters
----------
neurograph : NeuroGraph
neurograph : FragmentsGraph
NeuroGraph generated from a predicted segmentation.
computation_graph : networkx.Graph
Graph used by GNN to classify proposals.
@@ -192,7 +194,7 @@ def run_on_proposals(self, neurograph, proposals, radius):
Parameters
----------
neurograph : NeuroGraph
neurograph : FragmentsGraph
NeuroGraph generated from a predicted segmentation.
proposals : list[frozenset]
List of proposals for which features will be generated.
@@ -219,7 +221,7 @@ def node_skeletal(self, neurograph, computation_graph):
Parameters
----------
neurograph : NeuroGraph
neurograph : FragmentsGraph
NeuroGraph generated from a predicted segmentation.
computation_graph : networkx.Graph
Graph used by GNN to classify proposals.
@@ -248,8 +250,8 @@ def branch_skeletal(self, neurograph, computation_graph):
Parameters
----------
neurograph : NeuroGraph
NeuroGraph generated from a predicted segmentation.
neurograph : FragmentsGraph
Fragments graph that features are to be generated from.
computation_graph : networkx.Graph
Graph used by GNN to classify proposals.
@@ -275,7 +277,7 @@ def proposal_skeletal(self, neurograph, proposals, radius):
Parameters
----------
neurograph : NeuroGraph
neurograph : FragmentsGraph
NeuroGraph generated from a predicted segmentation.
proposals : list[frozenset]
List of proposals for which features will be generated.
@@ -311,7 +313,7 @@ def node_profiles(self, neurograph, computation_graph):
Parameters
----------
neurograph : NeuroGraph
neurograph : FragmentsGraph
NeuroGraph generated from a predicted segmentation.
computation_graph : networkx.Graph
Graph used by GNN to classify proposals.
@@ -349,7 +351,7 @@ def proposal_profiles(self, neurograph, proposals):
Parameters
----------
neurograph : NeuroGraph
neurograph : FragmentsGraph
Graph that "proposals" belong to.
proposals : list[frozenset]
List of proposals for which features will be generated.
@@ -382,7 +384,7 @@ def proposal_patches(self, neurograph, proposals):
Parameters
----------
neurograph : NeuroGraph
neurograph : FragmentsGraph
Graph that "proposals" belong to.
proposals : list[frozenset]
List of proposals for which features will be generated.
@@ -471,7 +473,7 @@ def transform_path(self, xyz_path):
"""
voxels = np.zeros((len(xyz_path), 3), dtype=int)
for i, xyz in enumerate(xyz_path):
voxels[i] = img_util.to_voxels(xyz, self.downsample_factor)
voxels[i] = img_util.to_voxels(xyz, self.anisotropy, self.multiscale)
return voxels

def get_bbox(self, voxels, is_img=True):
@@ -486,15 +488,15 @@ def get_patch(self, labels, xyz_path, proposal):
def get_patch(self, labels, xyz_path, proposal):
# Initializations
center = np.mean(xyz_path, axis=0)
voxels = [img_util.to_voxels(xyz) for xyz in xyz_path]
voxels = [img_util.to_voxels(xyz, self.anisotropy) for xyz in xyz_path]

# Read patches
img_patch = self.read_img_patch(center)
label_patch = self.read_label_patch(voxels, labels)
return {proposal: np.stack([img_patch, label_patch], axis=0)}

def read_img_patch(self, xyz_centroid):
center = img_util.to_voxels(xyz_centroid, self.downsample_factor)
center = img_util.to_voxels(xyz_centroid, self.anisotropy, self.multiscale)
img_patch = img_util.read_tensorstore(
self.img, center, self.img_patch_shape
)
@@ -509,7 +511,7 @@ def read_label_patch(self, voxels, labels):
def relabel(self, label_patch, voxels, labels):
# Initializations
n_points = self.get_n_profile_points()
scaling_factor = 2 ** self.downsample_factor
scaling_factor = 2 ** self.multiscale
label_patch = zoom(label_patch, 1.0 / scaling_factor, order=0)
for i, voxel in enumerate(voxels):
voxels[i] = [v // scaling_factor for v in voxel]
@@ -529,7 +531,7 @@ def get_leaf_path(neurograph, i):
Parameters
----------
neurograph : NeuroGraph
neurograph : FragmentsGraph
NeuroGraph generated from a predicted segmentation.
i : int
Leaf node in "neurograph".
@@ -551,7 +553,7 @@ def get_branching_path(neurograph, i):
Parameters
----------
neurograph : NeuroGraph
neurograph : FragmentsGraph
NeuroGraph generated from a predicted segmentation.
i : int
branching node in "neurograph".
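A hedged usage sketch of the refactored FeatureGenerator constructor shown above; the import path follows the file location in this diff, while the bucket path, pyramid level, and voxel sizes are placeholders.

from deep_neurographs.machine_learning.feature_generation import FeatureGenerator

feature_generator = FeatureGenerator(
    img_path="gs://my-bucket/sample.zarr",  # hypothetical GCS path to the raw image
    multiscale=2,                           # pyramid level that voxel coordinates index into
    anisotropy=[0.5, 0.5, 1.0],             # microns per voxel (made up)
    label_path=None,                        # no segmentation, so no image embeddings
    is_multimodal=False,
)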
10 changes: 9 additions & 1 deletion src/deep_neurographs/machine_learning/heterograph_datasets.py
@@ -384,9 +384,17 @@ def set_edge_attrs(self, x_nodes, edge_type, idx_map):
e1, e2 = self.data[edge_type].edge_index[:, i]
v = node_intersection(idx_map, e1, e2)
if v < 0:
attrs.append(torch.zeros(self.n_branch_features() + 1))
attrs.append(np.zeros(self.n_branch_features() + 1))
else:
attrs.append(x_nodes[v])

#print(edge_type, attrs[0].size())
try:
np.array(attrs)
#print(edge_type, v, attrs)
except:
print(edge_type, v, attrs)
stop
arrs = torch.tensor(np.array(attrs), dtype=DTYPE)
self.data[edge_type].edge_attr = arrs

4 changes: 2 additions & 2 deletions src/deep_neurographs/utils/graph_util.py
@@ -61,7 +61,7 @@ def __init__(
Parameters
----------
anisotropy : list[float], optional
Scaling factors applied to xyz coordinates to account for
Scaling factors applied to xyz coordinates to account for the
anisotropy of microscope. The default is [1.0, 1.0, 1.0].
min_size : float, optional
Minimum path length of swc files which are stored as connected
@@ -227,7 +227,7 @@ def clip_branches(self, graph, swc_id):
if self.img_bbox:
delete_nodes = set()
for i in graph.nodes:
xyz = img_util.to_voxels(graph.nodes[i]["xyz"])
xyz = img_util.to_voxels(graph.nodes[i]["xyz"], self.to_anisotropy)
if not util.is_contained(self.img_bbox, xyz):
delete_nodes.add(i)
graph.remove_nodes_from(delete_nodes)
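A minimal sketch of the clipping idea in clip_branches above: convert a node's physical xyz to voxels using the anisotropy factors, then drop the node if it falls outside the bounding box. The bbox layout ({"min", "max"}) follows the docstring earlier in this diff; the helper below is illustrative and not the repository's util.is_contained.

import numpy as np

def is_contained(bbox, voxel, buffer=0):
    # True if the voxel lies inside the box, shrunk on all sides by "buffer".
    lower = np.asarray(bbox["min"]) + buffer
    upper = np.asarray(bbox["max"]) - buffer
    return bool(np.all(voxel >= lower) and np.all(voxel < upper))

bbox = {"min": [0, 0, 0], "max": [128, 128, 64]}
voxel = np.array([200.0, 50.0, 10.0]) / np.array([0.5, 0.5, 1.0])  # physical xyz -> voxels
print(is_contained(bbox, voxel))  # False: x lands at 400, outside the box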