# intro_animation.py -- forked from InsightSoftwareConsortium/SimpleITK-Notebooks.
#
# Script for generating images illustrating the movement of images and change in
# similarity metric during registration.
#
import SimpleITK as sitk
import matplotlib
matplotlib.use("agg")
import matplotlib.pyplot as plt
import numpy as np
def write_combined_image(image1, image2, horizontal_space, file_name):
    """Paste two 2D images side by side and write the result to ``file_name``.

    image1 is placed at the top-left corner of the canvas. image2 is placed
    to its right, after a gap of ``horizontal_space`` pixel columns, and is
    centered vertically within the canvas. Canvas pixels not covered by
    either image keep whatever value ``sitk.Image`` initializes them with.

    Args:
        image1: left image; its pixel type and number of components define
            the canvas type, so image2 should share them for ``sitk.Paste``.
        image2: right image.
        horizontal_space: number of empty pixel columns between the images.
        file_name: destination path handed to ``sitk.WriteImage``.
    """
    left_w, left_h = image1.GetWidth(), image1.GetHeight()
    right_w, right_h = image2.GetWidth(), image2.GetHeight()

    canvas_size = (left_w + right_w + horizontal_space, max(left_h, right_h))
    canvas = sitk.Image(
        canvas_size, image1.GetPixelID(), image1.GetNumberOfComponentsPerPixel()
    )

    # Left image anchored at the canvas origin.
    canvas = sitk.Paste(canvas, image1, image1.GetSize(), (0, 0), (0, 0))

    # Right image shifted past the left image plus the gap, centered in y.
    right_corner = (
        left_w + horizontal_space,
        round((canvas.GetHeight() - right_h) / 2),
    )
    canvas = sitk.Paste(canvas, image2, image2.GetSize(), (0, 0), right_corner)

    sitk.WriteImage(canvas, file_name)
def start_plot():
    """StartEvent callback: reset the per-run plotting state.

    Binds fresh, empty module-level lists: ``metric_values`` for the metric
    history and ``multires_iterations`` for the indices at which a new
    multi-resolution level begins.
    """
    global metric_values, multires_iterations
    metric_values, multires_iterations = [], []
def end_plot():
    """EndEvent callback: discard the plotting state and close the figure.

    Removes the module-level ``metric_values`` and ``multires_iterations``
    lists, then closes the current pyplot figure so a duplicate of the plot
    does not show up later on.
    """
    global metric_values, multires_iterations
    del metric_values, multires_iterations
    plt.close()
def _figure_to_rgb_image():
    """Draw the current pyplot figure and return it as a 2D RGB uint8 image.

    Uses the canvas ``buffer_rgba()`` API, which the agg backend selected at
    import time provides. The previously used ``tostring_rgb()`` was
    deprecated in Matplotlib 3.8 and has since been removed, and the binary
    mode of ``np.fromstring`` is deprecated in NumPy.
    """
    canvas = plt.gcf().canvas
    canvas.draw()
    rgba = np.asarray(canvas.buffer_rgba())
    # Drop the alpha channel and hand SimpleITK a contiguous copy of the data.
    rgb = np.ascontiguousarray(rgba[:, :, :3])
    return sitk.GetImageFromArray(rgb, isVector=True)


def _blend_central_slices(fixed, moving, transform, alpha=0.7):
    """Return the alpha-blended central z-slice of two volumes as 2D RGB uint8.

    ``moving`` is resampled onto ``fixed``'s grid with ``transform`` (linear
    interpolation, 0.0 outside the moving image). The slice blend is
    ``(1 - alpha) * fixed + alpha * moving``, rescaled to [0, 255] and
    replicated into three channels so it can be pasted next to an RGB plot.
    The displayed pixels are assumed to be isotropic (no spacing correction).
    """
    moving_transformed = sitk.Resample(
        moving, fixed, transform, sitk.sitkLinear, 0.0, moving.GetPixelIDValue()
    )
    # Floor division selects the true middle slice for odd sizes; round()
    # uses banker's rounding and picked e.g. index 2 of 3 slices.
    central_index = fixed.GetSize()[2] // 2
    blended = (1.0 - alpha) * fixed[:, :, central_index] + alpha * moving_transformed[
        :, :, central_index
    ]
    gray = sitk.Cast(sitk.RescaleIntensity(blended), sitk.sitkUInt8)
    return sitk.Compose(gray, gray, gray)


def save_plot(registration_method, fixed, moving, transform, file_name_prefix):
    """IterationEvent callback: record the metric and save a composite frame.

    Appends the current metric value to the module-level ``metric_values``
    list. It then writes a PNG with the blended central slice of ``fixed``
    and the transformed ``moving`` volume on the left, and the metric curve
    on the right. Level changes (indices in ``multires_iterations``) are
    marked with blue stars.

    Args:
        registration_method: object whose ``GetMetricValue()`` is recorded.
        fixed: 3D fixed volume; also the resampling reference grid.
        moving: 3D moving volume, resampled with ``transform``.
        transform: current transform mapping fixed points into moving space.
        file_name_prefix: output path prefix; the frame is written to
            ``file_name_prefix + <iteration as 3 digits> + ".png"``.
    """
    global metric_values, multires_iterations
    metric_values.append(registration_method.GetMetricValue())

    # Redraw from scratch. Without clearing, every call stacked another full
    # copy of the curve on the same figure, so drawing grew O(n^2) overall
    # while producing the same picture.
    plt.clf()
    plt.plot(metric_values, "r")
    plt.plot(
        multires_iterations,
        [metric_values[index] for index in multires_iterations],
        "b*",
    )
    plt.xlabel("Iteration Number", fontsize=12)
    plt.ylabel("Metric Value", fontsize=12)
    plot_image = _figure_to_rgb_image()

    # Uses the `moving` argument; the old code read the script-level
    # `moving_image` global and broke when called from elsewhere.
    combined_slices_image = _blend_central_slices(fixed, moving, transform)

    write_combined_image(
        combined_slices_image,
        plot_image,
        0,
        file_name_prefix + format(len(metric_values), "03d") + ".png",
    )
def update_multires_iterations():
    """MultiResolutionIterationEvent callback: mark where a new level starts.

    Records the current length of ``metric_values`` in ``multires_iterations``.
    That length is the index the next recorded metric value will occupy, so
    the plot can highlight the first iteration of each level.
    """
    global metric_values, multires_iterations
    level_start = len(metric_values)
    multires_iterations += [level_start]
if __name__ == "__main__":
    import os

    # Frames are written as <prefix>NNN.png. sitk.WriteImage does not create
    # missing directories, so ensure the destination exists before the first
    # IterationEvent fires (otherwise registration aborts on frame 1).
    file_name_prefix = "output/iteration_plot"
    os.makedirs(os.path.dirname(file_name_prefix), exist_ok=True)

    # Read the images. They are kept as module-level names because the
    # callbacks are wired to them below.
    fixed_image = sitk.ReadImage("training_001_ct.mha", sitk.sitkFloat32)
    moving_image = sitk.ReadImage("training_001_mr_T1.mha", sitk.sitkFloat32)

    # Initial alignment of the two volumes (geometric centers).
    transform = sitk.CenteredTransformInitializer(
        fixed_image,
        moving_image,
        sitk.Euler3DTransform(),
        sitk.CenteredTransformInitializerFilter.GEOMETRY,
    )

    # Multi-resolution rigid registration using Mutual Information.
    registration_method = sitk.ImageRegistrationMethod()
    registration_method.SetMetricAsMattesMutualInformation(numberOfHistogramBins=50)
    registration_method.SetMetricSamplingStrategy(registration_method.RANDOM)
    registration_method.SetMetricSamplingPercentage(0.01)
    registration_method.SetInterpolator(sitk.sitkLinear)
    registration_method.SetOptimizerAsGradientDescent(
        learningRate=1.0,
        numberOfIterations=100,
        convergenceMinimumValue=1e-6,
        convergenceWindowSize=10,
    )
    registration_method.SetOptimizerScalesFromPhysicalShift()
    registration_method.SetShrinkFactorsPerLevel(shrinkFactors=[4, 2, 1])
    registration_method.SetSmoothingSigmasPerLevel(smoothingSigmas=[2, 1, 0])
    registration_method.SmoothingSigmasAreSpecifiedInPhysicalUnitsOn()
    # inPlace=True (the default, stated explicitly): the optimizer updates
    # `transform` itself, which is what lets the iteration callback below
    # render the current alignment by passing this same object.
    registration_method.SetInitialTransform(transform, inPlace=True)

    # Add all the callbacks responsible for plotting.
    registration_method.AddCommand(sitk.sitkStartEvent, start_plot)
    registration_method.AddCommand(sitk.sitkEndEvent, end_plot)
    registration_method.AddCommand(
        sitk.sitkMultiResolutionIterationEvent, update_multires_iterations
    )
    registration_method.AddCommand(
        sitk.sitkIterationEvent,
        lambda: save_plot(
            registration_method,
            fixed_image,
            moving_image,
            transform,
            file_name_prefix,
        ),
    )

    registration_method.Execute(fixed_image, moving_image)