Skip to content

Commit

Permalink
Fix #2300 - fix progress bar and status for parallelized augmentation
Browse files Browse the repository at this point in the history
  • Loading branch information
akenmorris committed Nov 5, 2024
1 parent a69a344 commit 4249f34
Showing 1 changed file with 8 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,9 @@ def point_based_aug(out_dir, orig_img_list, orig_point_list, num_samples, num_di
if sw_check_abort():
sw_message("Aborted")
return 0
sw_message("Generating " +str(index)+'/'+str(num_samples))
sw_progress(index / (num_samples+1))
if processes==1:
sw_message("Generating " +str(index)+'/'+str(num_samples))
sw_progress(index / (num_samples+1))
name = 'Generated_sample_' + Utils.pad_index(index)
# Generate embedding
sampled_embedding, base_index = PointSampler.sample()
Expand Down Expand Up @@ -112,9 +113,12 @@ def point_based_aug(out_dir, orig_img_list, orig_point_list, num_samples, num_di
# write world to local transformation information for generated particles
with open(out_dir + '/world_get_local_info.json', 'w') as f:
json.dump(world_get_local_info, f)
if processes!=1:
if processes != 1:
with mtps.Pool(processes=processes) as p:
gen_image_paths = p.map(generate_image, generate_image_params_list)
for index, gen_image_path in enumerate(p.imap(generate_image, generate_image_params_list), 1):
sw_message("Generating image " + str(index) + '/' + str(num_samples))
sw_progress(index / (num_samples + 1))
gen_image_paths.append(gen_image_path)
csv_file = out_dir + "TotalData.csv"
Utils.make_CSV(out_dir + "TotalData.csv", orig_img_list, orig_point_list, embedded_matrix, gen_image_paths, gen_points_paths, gen_embeddings)
return num_dim
Expand Down

0 comments on commit 4249f34

Please sign in to comment.