Skip to content

Commit

Permalink
Add npy support to load_memmap
Browse files Browse the repository at this point in the history
  • Loading branch information
fdeguire03 authored Jan 9, 2025
1 parent 857ae12 commit 2a14837
Showing 1 changed file with 10 additions and 3 deletions.
13 changes: 10 additions & 3 deletions caiman/mmapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,12 @@ def load_memmap(filename: str, mode: str = 'r') -> tuple[Any, tuple, int]:
"""
logger = logging.getLogger("caiman")
if pathlib.Path(filename).suffix != '.mmap':
allowed_extensions = {'.mmap', '.npy'}

extension = pathlib.Path(filename).suffix
if extension not in allowed_extensions:
logger.error(f"Unknown extension for file {filename}")
raise ValueError(f'Unknown file extension for file {filename} (should be .mmap)')
raise ValueError(f'Unknown file extension for file {filename} (should be .mmap or .npy)')
# Strip path components and use CAIMAN_DATA/example_movies
# TODO: Eventually get the code to save these in a different dir
#fn_without_path = os.path.split(filename)[-1]
Expand All @@ -63,7 +66,11 @@ def load_memmap(filename: str, mode: str = 'r') -> tuple[Any, tuple, int]:
#d1, d2, d3, T, order = int(fpart[-9]), int(fpart[-7]), int(fpart[-5]), int(fpart[-1]), fpart[-3]

filename = caiman.paths.fn_relocated(filename)
Yr = np.memmap(filename, mode=mode, shape=prepare_shape((d1 * d2 * d3, T)), dtype=np.float32, order=order)
if extension == '.mmap':
Yr = np.memmap(filename, mode=mode, shape=prepare_shape((d1 * d2 * d3, T)), dtype=np.float32, order=order)
elif extension == '.npy':
Yr = np.load(filename, mmap_mode=mode)

if d3 == 1:
return (Yr, (d1, d2), T)
else:
Expand Down

0 comments on commit 2a14837

Please sign in to comment.