From 8fe3cb69265245d441aef2a561ce483016e388a6 Mon Sep 17 00:00:00 2001 From: Johannes Elferich Date: Tue, 20 Aug 2024 23:11:20 -0400 Subject: [PATCH] More polish for noise injection --- src/ttfsc/_cli.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/src/ttfsc/_cli.py b/src/ttfsc/_cli.py index 3142cdd..5060298 100644 --- a/src/ttfsc/_cli.py +++ b/src/ttfsc/_cli.py @@ -44,7 +44,7 @@ def ttfsc_cli( ] = True, correct_from_resolution: Annotated[ float, typer.Option("--correct-from_resolution", rich_help_panel="Masking correction options") - ] = True, + ] = 10.0, ) -> None: with mrcfile.open(map1) as f: map1_tensor = torch.tensor(f.data) @@ -93,14 +93,15 @@ def ttfsc_cli( norm=True, device=map1_tensor_randomized.device, ) + to_correct = frequency_grid > (1 / correct_from_resolution) / pixel_spacing_angstroms # Rotate phases at frequencies higher than 0.25 - random_phases1 = torch.rand(frequency_grid[frequency_grid > 0.25].shape) * 2 * torch.pi + random_phases1 = torch.rand(frequency_grid[to_correct].shape) * 2 * torch.pi random_phases1 = torch.complex(torch.cos(random_phases1), torch.sin(random_phases1)) - random_phases2 = torch.rand(frequency_grid[frequency_grid > 0.25].shape) * 2 * torch.pi + random_phases2 = torch.rand(frequency_grid[to_correct].shape) * 2 * torch.pi random_phases2 = torch.complex(torch.cos(random_phases2), torch.sin(random_phases2)) - map1_tensor_randomized[frequency_grid > 0.25] *= random_phases1 - map2_tensor_randomized[frequency_grid > 0.25] *= random_phases2 + map1_tensor_randomized[to_correct] *= random_phases1 + map2_tensor_randomized[to_correct] *= random_phases2 map1_tensor_randomized = torch.fft.irfftn(map1_tensor_randomized) map2_tensor_randomized = torch.fft.irfftn(map2_tensor_randomized) @@ -112,10 +113,12 @@ def ttfsc_cli( map2_tensor = map2_tensor * mask_tensor fsc_values_masked = fsc(map1_tensor, map2_tensor) if correct_for_masking: + to_correct = frequency_pixels > (1 / correct_from_resolution) / pixel_spacing_angstroms + print(to_correct) fsc_values_corrected = fsc_values_masked.clone() - fsc_values_corrected[frequency_pixels > 0.25] = ( - fsc_values_corrected[frequency_pixels > 0.25] - fsc_values_masked_randomized[frequency_pixels > 0.25] - ) / (1.0 - fsc_values_masked_randomized[frequency_pixels > 0.25]) + fsc_values_corrected[to_correct] = ( + fsc_values_corrected[to_correct] - fsc_values_masked_randomized[to_correct] + ) / (1.0 - fsc_values_masked_randomized[to_correct]) estimated_resolution_frequency_pixel = float(frequency_pixels[(fsc_values_masked < fsc_threshold).nonzero()[0] - 1]) estimated_resolution_angstrom = float(resolution_angstroms[(fsc_values_masked < fsc_threshold).nonzero()[0] - 1])