diff --git a/foa3d/frangi.py b/foa3d/frangi.py index f620780..356e3ec 100644 --- a/foa3d/frangi.py +++ b/foa3d/frangi.py @@ -304,30 +304,31 @@ def frangi_filter(img, scales_px=1, alpha=0.001, beta=1.0, gamma=None, dark=True enhanced_img, fiber_vec, eigenval \ = compute_scaled_orientation(scales_px[0], img, alpha=alpha, beta=beta, gamma=gamma, dark=dark) - # parallel scaled vesselness analysis + # sequential scaled vesselness analysis else: - with Parallel(n_jobs=n_scales, backend='threading', max_nbytes=None) as parallel: - par_res = \ - parallel( - delayed(compute_scaled_orientation)( - scales_px[i], img=img, - alpha=alpha, beta=beta, gamma=gamma, dark=dark) for i in range(n_scales)) - - # unpack and stack results - enhanced_img_tpl, eigenvectors_tpl, eigenvalues_tpl = zip(*par_res) - eigenval = np.stack(eigenvalues_tpl, axis=0) - eigenvec = np.stack(eigenvectors_tpl, axis=0) - enhanced_img = np.stack(enhanced_img_tpl, axis=0) - - # get max scale-wise vesselness - best_idx = np.argmax(enhanced_img, axis=0) - best_idx = np.expand_dims(best_idx, axis=0) - enhanced_img = np.take_along_axis(enhanced_img, best_idx, axis=0).squeeze(axis=0) - - # select fiber orientation vectors (and the associated eigenvalues) among different scales - best_idx = np.expand_dims(best_idx, axis=-1) - eigenval = np.take_along_axis(eigenval, best_idx, axis=0).squeeze(axis=0) - fiber_vec = np.take_along_axis(eigenvec, best_idx, axis=0).squeeze(axis=0) + + # initialize output arrays + scl_shp = (n_scales,) + img.shape + enhanced_img = np.zeros(scl_shp) + + eig_shp = scl_shp + img.ndim + eigenval = np.zeros(eig_shp) + eigenvec = np.zeros(eig_shp) + + # iterate over scales + for s in range(n_scales): + enhanced_img[s], eigenvec[s], eigenval[s] \ + = compute_scaled_orientation(scales_px[s], img=img, alpha=alpha, beta=beta, gamma=gamma, dark=dark) + + # get max scale-wise vesselness + best_idx = np.argmax(enhanced_img, axis=0) + best_idx = np.expand_dims(best_idx, axis=0) + enhanced_img = np.take_along_axis(enhanced_img, best_idx, axis=0).squeeze(axis=0) + + # select fiber orientation vectors (and the associated eigenvalues) among different scales + best_idx = np.expand_dims(best_idx, axis=-1) + eigenval = np.take_along_axis(eigenval, best_idx, axis=0).squeeze(axis=0) + fiber_vec = np.take_along_axis(eigenvec, best_idx, axis=0).squeeze(axis=0) return enhanced_img, fiber_vec, eigenval diff --git a/foa3d/pipeline.py b/foa3d/pipeline.py index 372aaf6..966da96 100644 --- a/foa3d/pipeline.py +++ b/foa3d/pipeline.py @@ -717,8 +717,7 @@ def parallel_frangi_on_slices(img, cli_args, save_dir, tmp_dir, img_name, get_image_info(img, px_size, mask_lpf, ch_mye, is_tiled=is_tiled) # configure batch of basic image slices analyzed in parallel - batch_size, max_slice_size = \ - config_frangi_batch(frangi_sigma_um, ram=ram, jobs=jobs) + batch_size, max_slice_size = config_frangi_batch(ram=ram, jobs=jobs) # get info on the processed image slices rng_in_lst, rng_in_lpf_lst, rng_out_lst, pad_lst, \ diff --git a/foa3d/slicing.py b/foa3d/slicing.py index 360ed8f..3b8a330 100644 --- a/foa3d/slicing.py +++ b/foa3d/slicing.py @@ -222,17 +222,13 @@ def compute_overlap_range(smooth_sigma, frangi_sigma, px_rsz_ratio, truncate=2): return ovlp -def config_frangi_batch(frangi_scales, mem_growth_factor=149.7, mem_fudge_factor=1.0, - min_slice_size=-1, jobs=None, ram=None): +def config_frangi_batch(mem_growth_factor=149.7, mem_fudge_factor=1.0, min_slice_size=-1, jobs=None, ram=None): """ Compute size and number of the batches of basic microscopy image slices analyzed in parallel. Parameters ---------- - frangi_scales: list (dtype=float) - analyzed spatial scales in [μm] - mem_growth_factor: float empirical memory growth factor of the Frangi filtering stage @@ -268,19 +264,16 @@ def config_frangi_batch(frangi_scales, mem_growth_factor=149.7, mem_fudge_factor if jobs is None: jobs = num_cpu - # number of spatial scales - num_scales = len(frangi_scales) - # initialize slice batch size - slice_batch_size = np.min([jobs // num_scales, num_cpu]).astype(int) + slice_batch_size = np.min([jobs, num_cpu]).astype(int) if slice_batch_size == 0: slice_batch_size = 1 # get image slice size - slice_size = get_slice_size(ram, mem_growth_factor, mem_fudge_factor, slice_batch_size, num_scales) + slice_size = get_slice_size(ram, mem_growth_factor, mem_fudge_factor, slice_batch_size) while slice_size < min_slice_size: slice_batch_size -= 1 - slice_size = get_slice_size(ram, mem_growth_factor, mem_fudge_factor, slice_batch_size, num_scales) + slice_size = get_slice_size(ram, mem_growth_factor, mem_fudge_factor, slice_batch_size) return slice_batch_size, slice_size