Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: prevent error when FSC does not drop below threshold. #5

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions src/ttfsc/_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,12 @@ def ttfsc_cli(
str, typer.Option("--plot-matplotlib-style", rich_help_panel="Plotting options")
] = "default",
mask: Annotated[Masking, typer.Option("--mask", rich_help_panel="Masking options")] = Masking.none,
mask_file: Annotated[
Optional[Path],
typer.Option(
"--mask-file", help="Path to custom mask file (required when mask=custom)", rich_help_panel="Masking options"
),
] = None,
mask_radius_angstroms: Annotated[
float, typer.Option("--mask-radius-angstroms", rich_help_panel="Masking options")
] = 100.0,
Expand All @@ -54,12 +60,17 @@ def ttfsc_cli(
float, typer.Option("--correct-from-fraction-of-estimated-resolution", rich_help_panel="Masking correction options")
] = 0.5,
) -> None:
if mask == Masking.custom and mask_file is None:
raise typer.BadParameter("--mask-file is required when using --mask=custom")
if mask == Masking.sphere and mask_file is not None:
rprint("[yellow]Warning: --mask-file is ignored when using --mask=sphere[/yellow]")
result = ttfsc(
map1=map1,
map2=map2,
pixel_spacing_angstroms=pixel_spacing_angstroms,
fsc_threshold=fsc_threshold,
mask=mask,
mask_filename=mask_file,
mask_radius_angstroms=mask_radius_angstroms,
mask_soft_edge_width_pixels=mask_soft_edge_width_pixels,
correct_for_masking=correct_for_masking,
Expand Down
1 change: 1 addition & 0 deletions src/ttfsc/_data_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
class Masking(str, Enum):
none = "none"
sphere = "sphere"
custom = "custom"


class TTFSCResult(BaseModel):
Expand Down
66 changes: 54 additions & 12 deletions src/ttfsc/_masking.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import mrcfile
import torch
from torch_fourier_shell_correlation import fsc

Expand Down Expand Up @@ -58,12 +59,18 @@ def calculate_noise_injected_fsc(result: TTFSCResult) -> None:
result.fsc_values_corrected[to_correct] - result.fsc_values_masked_randomized[to_correct]
) / (1.0 - result.fsc_values_masked_randomized[to_correct])

result.estimated_resolution_frequency_pixel = float(
result.frequency_pixels[(result.fsc_values_corrected < result.fsc_threshold).nonzero()[0] - 1]
)
result.estimated_resolution_angstrom = float(
result.resolution_angstroms[(result.fsc_values_corrected < result.fsc_threshold).nonzero()[0] - 1]
)
# Find indices where FSC is below threshold
below_threshold_indices = (result.fsc_values_corrected < result.fsc_threshold).nonzero()

if len(below_threshold_indices) > 0:
# Use the first crossing point if it exists
index = below_threshold_indices[0] - 1
result.estimated_resolution_frequency_pixel = float(result.frequency_pixels[index])
result.estimated_resolution_angstrom = float(result.resolution_angstroms[index])
else:
# If no values below threshold, use the highest frequency (Nyquist)
result.estimated_resolution_frequency_pixel = float(result.frequency_pixels[-1])
result.estimated_resolution_angstrom = float(result.resolution_angstroms[-1])
result.estimated_resolution_angstrom_corrected = result.estimated_resolution_angstrom


Expand Down Expand Up @@ -93,12 +100,47 @@ def calculate_masked_fsc(result: TTFSCResult) -> None:
map2_tensor_masked = result.map2_tensor * result.mask_tensor
result.fsc_values_masked = fsc(map1_tensor_masked, map2_tensor_masked)

result.estimated_resolution_frequency_pixel = float(
result.frequency_pixels[(result.fsc_values_masked < result.fsc_threshold).nonzero()[0] - 1]
)
result.estimated_resolution_angstrom = float(
result.resolution_angstroms[(result.fsc_values_masked < result.fsc_threshold).nonzero()[0] - 1]
)
# Find indices where FSC is below threshold
below_threshold_indices = (result.fsc_values_masked < result.fsc_threshold).nonzero()

if len(below_threshold_indices) > 0:
# Use the first crossing point if it exists
index = below_threshold_indices[0] - 1
result.estimated_resolution_frequency_pixel = float(result.frequency_pixels[index])
result.estimated_resolution_angstrom = float(result.resolution_angstroms[index])
else:
# If no values below threshold, use the highest frequency (Nyquist)
result.estimated_resolution_frequency_pixel = float(result.frequency_pixels[-1])
result.estimated_resolution_angstrom = float(result.resolution_angstroms[-1])
result.estimated_resolution_angstrom_masked = result.estimated_resolution_angstrom

return
elif result.mask == Masking.custom:
if result.mask_filename is None:
raise ValueError("Must provide mask_filename for custom mask")

with mrcfile.open(result.mask_filename) as f:
result.mask_tensor = torch.tensor(f.data)

if result.mask_tensor.shape != result.map1_tensor.shape:
raise ValueError(f"Mask shape {result.mask_tensor.shape} does not match map shape {result.map1_tensor.shape}")

map1_tensor_masked = result.map1_tensor * result.mask_tensor
map2_tensor_masked = result.map2_tensor * result.mask_tensor
result.fsc_values_masked = fsc(map1_tensor_masked, map2_tensor_masked)

# Find indices where FSC is below threshold
below_threshold_indices = (result.fsc_values_masked < result.fsc_threshold).nonzero()

if len(below_threshold_indices) > 0:
# Use the first crossing point if it exists
index = below_threshold_indices[0] - 1
result.estimated_resolution_frequency_pixel = float(result.frequency_pixels[index])
result.estimated_resolution_angstrom = float(result.resolution_angstroms[index])
else:
# If no values below threshold, use the highest frequency (Nyquist)
result.estimated_resolution_frequency_pixel = float(result.frequency_pixels[-1])
result.estimated_resolution_angstrom = float(result.resolution_angstroms[-1])
result.estimated_resolution_angstrom_masked = result.estimated_resolution_angstrom

return
Expand Down
19 changes: 17 additions & 2 deletions src/ttfsc/ttfsc.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def ttfsc(
pixel_spacing_angstroms: Optional[float] = None,
fsc_threshold: float = 0.143,
mask: Masking = Masking.none,
mask_filename: Optional[Path] = None,
mask_radius_angstroms: float = 100.0,
mask_soft_edge_width_pixels: int = 10,
correct_for_masking: bool = True,
Expand All @@ -40,6 +41,7 @@ def ttfsc(
pixel_spacing_angstroms (Optional[float]): Pixel spacing in Å/px. If not provided, it will be taken from the header.
fsc_threshold (float): FSC threshold value. Default is 0.143.
mask (Masking): Masking option to use. Default is Masking.none.
mask_filename (Optional[Path]): Path to the mask file. Default is None.
mask_radius_angstroms (float): Radius of the mask in Å. Default is 100.0.
mask_soft_edge_width_pixels (int): Width of the soft edge of the mask in pixels. Default is 10.
correct_for_masking (bool): Whether to correct for masking effects. Default is True.
Expand All @@ -59,6 +61,7 @@ def ttfsc(
pixel_spacing_angstroms=1.0,
fsc_threshold=0.143,
mask=Masking.soft,
mask_filename=Path("mask.mrc"),
mask_radius_angstroms=150.0,
mask_soft_edge_width_pixels=5,
correct_for_masking=True,
Expand All @@ -78,8 +81,19 @@ def ttfsc(

fsc_values_unmasked = fsc(map1_tensor, map2_tensor)

estimated_resolution_frequency_pixel = float(frequency_pixels[(fsc_values_unmasked < fsc_threshold).nonzero()[0] - 1])
estimated_resolution_angstrom = float(resolution_angstroms[(fsc_values_unmasked < fsc_threshold).nonzero()[0] - 1])
# Find indices where FSC is below threshold
below_threshold_indices = (fsc_values_unmasked < fsc_threshold).nonzero()

if len(below_threshold_indices) > 0:
# Use the first crossing point if it exists
index = below_threshold_indices[0] - 1
estimated_resolution_frequency_pixel = float(frequency_pixels[index])
estimated_resolution_angstrom = float(resolution_angstroms[index])
else:
# If no values below threshold, use the highest frequency (Nyquist)
estimated_resolution_frequency_pixel = float(frequency_pixels[-1])
estimated_resolution_angstrom = float(resolution_angstroms[-1])

result = TTFSCResult(
map1=map1,
map1_tensor=map1_tensor,
Expand All @@ -99,6 +113,7 @@ def ttfsc(
from ._masking import calculate_masked_fsc

result.mask = mask
result.mask_filename = mask_filename
result.mask_radius_angstroms = mask_radius_angstroms
result.mask_soft_edge_width_pixels = mask_soft_edge_width_pixels
calculate_masked_fsc(result)
Expand Down
Loading