Skip to content

Commit

Permalink
Upgrading a unit test.
Browse files Browse the repository at this point in the history
  • Loading branch information
kaanaksit committed Jan 4, 2025
1 parent 7112e4a commit d739cc0
Showing 1 changed file with 29 additions and 6 deletions.
35 changes: 29 additions & 6 deletions test/test_learn_wave_quadratic_phase_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def test(
pixel_pitch = 3.74e-6,
number_of_frames = 3,
number_of_depth_layers = 10,
volume_depth = 10e-3,
volume_depth = 1e-2,
image_location_offset = 3e-2,
propagation_type = 'Bandlimited Angular Spectrum',
propagator_type = 'forward',
Expand All @@ -20,24 +20,35 @@ def test(
[0., 1., 0.],
[0., 0., 1.],
]),
lens_focus = 3e-2,
offsets = [
[-20, -20],
[0, 0],
[30, 50]
],
lens_focus = [
28e-3,
30e-3,
32e-3
],
aperture = None,
aperture_size = None,
method = 'conventional',
device = torch.device('cpu'),
output_directory = 'test_output'
):
odak.tools.check_directory(output_directory)
hologram_phases = torch.ones(number_of_frames, resolution[0], resolution[1], device = device)
hologram_phases = torch.zeros(number_of_frames, resolution[0], resolution[1], device = device)
hologram_amplitudes = torch.ones(number_of_frames, resolution[0], resolution[1], device = device)
for frame_id in range(number_of_frames):
wavelength = wavelengths[frame_id]
k = odak.learn.wave.wavenumber(wavelength)
lens_complex = odak.learn.wave.quadratic_phase_function(
nx = resolution[0],
ny = resolution[1],
k = k,
focal = lens_focus,
dx = pixel_pitch
focal = lens_focus[frame_id],
dx = pixel_pitch,
offset = offsets[frame_id]
)
lens_phase = odak.learn.wave.calculate_phase(lens_complex).to(device).unsqueeze(0) % (2. * torch.pi)
hologram_phases[frame_id] = lens_phase
Expand All @@ -63,7 +74,11 @@ def test(
method = method,
device = device
)
reconstruction_intensities = propagator.reconstruct(hologram_phases, amplitude = None)
reconstruction_intensities = propagator.reconstruct(
hologram_phases = hologram_phases,
amplitude = hologram_amplitudes
)
reconstruction_intensities_rgb = torch.sum(reconstruction_intensities, dim = 0)
for frame_id in range(reconstruction_intensities.shape[0]):
for depth_id in range(reconstruction_intensities.shape[1]):
reconstruction_intensity = reconstruction_intensities[frame_id, depth_id]
Expand All @@ -73,6 +88,14 @@ def test(
cmin = 0.,
cmax = reconstruction_intensities.max()
)
reconstruction_intensity = reconstruction_intensities_rgb[depth_id]
odak.learn.tools.save_image(
'{}/lens_reconstruction_image_rgb_{:03d}.png'.format(output_directory, depth_id),
reconstruction_intensity,
cmin = 0.,
cmax = reconstruction_intensities_rgb.max()
)

assert True == True


Expand Down

0 comments on commit d739cc0

Please sign in to comment.