Skip to content

Commit

Permalink
More polish for noise injection
Browse files Browse the repository at this point in the history
  • Loading branch information
jojoelfe committed Aug 21, 2024
1 parent e974447 commit 8fe3cb6
Showing 1 changed file with 11 additions and 8 deletions.
19 changes: 11 additions & 8 deletions src/ttfsc/_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def ttfsc_cli(
] = True,
correct_from_resolution: Annotated[
float, typer.Option("--correct-from_resolution", rich_help_panel="Masking correction options")
] = True,
] = 10.0,
) -> None:
with mrcfile.open(map1) as f:
map1_tensor = torch.tensor(f.data)
Expand Down Expand Up @@ -93,14 +93,15 @@ def ttfsc_cli(
norm=True,
device=map1_tensor_randomized.device,
)
to_correct = frequency_grid > (1 / correct_from_resolution) / pixel_spacing_angstroms
# Rotate phases at frequencies higher than 0.25
random_phases1 = torch.rand(frequency_grid[frequency_grid > 0.25].shape) * 2 * torch.pi
random_phases1 = torch.rand(frequency_grid[to_correct].shape) * 2 * torch.pi
random_phases1 = torch.complex(torch.cos(random_phases1), torch.sin(random_phases1))
random_phases2 = torch.rand(frequency_grid[frequency_grid > 0.25].shape) * 2 * torch.pi
random_phases2 = torch.rand(frequency_grid[to_correct].shape) * 2 * torch.pi
random_phases2 = torch.complex(torch.cos(random_phases2), torch.sin(random_phases2))

map1_tensor_randomized[frequency_grid > 0.25] *= random_phases1
map2_tensor_randomized[frequency_grid > 0.25] *= random_phases2
map1_tensor_randomized[to_correct] *= random_phases1
map2_tensor_randomized[to_correct] *= random_phases2

map1_tensor_randomized = torch.fft.irfftn(map1_tensor_randomized)
map2_tensor_randomized = torch.fft.irfftn(map2_tensor_randomized)
Expand All @@ -112,10 +113,12 @@ def ttfsc_cli(
map2_tensor = map2_tensor * mask_tensor
fsc_values_masked = fsc(map1_tensor, map2_tensor)
if correct_for_masking:
to_correct = frequency_pixels > (1 / correct_from_resolution) / pixel_spacing_angstroms
print(to_correct)
fsc_values_corrected = fsc_values_masked.clone()
fsc_values_corrected[frequency_pixels > 0.25] = (
fsc_values_corrected[frequency_pixels > 0.25] - fsc_values_masked_randomized[frequency_pixels > 0.25]
) / (1.0 - fsc_values_masked_randomized[frequency_pixels > 0.25])
fsc_values_corrected[to_correct] = (
fsc_values_corrected[to_correct] - fsc_values_masked_randomized[to_correct]
) / (1.0 - fsc_values_masked_randomized[to_correct])

estimated_resolution_frequency_pixel = float(frequency_pixels[(fsc_values_masked < fsc_threshold).nonzero()[0] - 1])
estimated_resolution_angstrom = float(resolution_angstroms[(fsc_values_masked < fsc_threshold).nonzero()[0] - 1])
Expand Down

0 comments on commit 8fe3cb6

Please sign in to comment.