diff --git a/src/miv_simulator/coding.py b/src/miv_simulator/coding.py
new file mode 100644
index 0000000..70a12f8
--- /dev/null
+++ b/src/miv_simulator/coding.py
@@ -0,0 +1,112 @@
+import numpy as np
+from numpy.typing import NDArray
+from typing import Annotated as EventArray, Dict
+
+SpikeTimesLike = EventArray[NDArray[np.float64], "SpikeTimesLike ..."]
+"""Potentially unsorted or scalar data that can be transformed into `SpikeTimes`"""
+
+SpikeTimes = EventArray[NDArray[np.float64], "SpikeTimes T ..."]
+"""Sorted array of absolute spike times"""
+
+
+# Spike train encodings (RLE, delta encoding, variable time binning etc.)
+
+BinarySparseSpikeTrainLike = EventArray[
+    NDArray, "BinarySparseSpikeTrainLike ..."
+]
+"""Binary data that can be cast to the `BinarySparseSpikeTrain` format"""
+
+
+BinarySparseSpikeTrain = EventArray[
+    NDArray[np.int8], "BinarySparseSpikeTrain t_bin ..."
+]
+"""Binary spike train representation for a given temporal resolution"""
+
+
+def _inspect(type_) -> Dict:
+    annotation = type_.__metadata__[0]
+    name, *dims = annotation.split(" ")
+
+    return {
+        "annotation": annotation,
+        "name": name,
+        "dims": dims,
+        "dtype": type_.__origin__.__args__[1].__args__[0],
+    }
+
+
+def _cast(a, a_type, r_type):  # -> r_type
+    a_t, r_t = _inspect(a_type), _inspect(r_type)
+    if a_t["name"].replace("Like", "") != r_t["name"]:
+        raise ValueError(
+            f"Expected miv_simulator.coding.{r_t['name']}Like but found {a_t['name']}"
+        )
+    v = np.array(a, dtype=r_t["dtype"])
+    if len(v.shape) == 0:
+        # promote scalar inputs to 1-D arrays
+        return np.reshape(v, (1,))
+    return v
+
+
+def cast_spike_times(a: SpikeTimesLike) -> SpikeTimes:
+    return np.sort(_cast(a, SpikeTimesLike, SpikeTimes), axis=0)
+
+
+def cast_binary_sparse_spike_train(
+    a: BinarySparseSpikeTrainLike,
+) -> BinarySparseSpikeTrain:
+    return _cast(a, BinarySparseSpikeTrainLike, BinarySparseSpikeTrain)
+
+
+def spike_times_2_binary_sparse_spike_train(
+    array: SpikeTimesLike, temporal_resolution: float
+) -> BinarySparseSpikeTrain:
+    a = cast_spike_times(array)
+    if a.size == 0:
+        # no spikes: return an empty train instead of indexing into empty bins
+        return np.zeros(0, dtype=np.int8)
+    bins = np.floor(a / temporal_resolution).astype(int)
+    # since a is sorted, the maximum bin index is the last value
+    spike_train = np.zeros(bins[-1] + 1, dtype=np.int8)
+    spike_train[bins] = 1
+    return spike_train
+
+
+def binary_sparse_spike_train_2_spike_times(
+    array: BinarySparseSpikeTrainLike, temporal_resolution: float
+) -> SpikeTimes:
+    a = cast_binary_sparse_spike_train(array)
+    spike_indices = np.where(a == 1)[0]
+    spike_times = spike_indices * temporal_resolution
+    return spike_times
+
+
+def adjust_temporal_resolution(
+    array: BinarySparseSpikeTrainLike,
+    original_resolution: float,
+    target_resolution: float,
+) -> BinarySparseSpikeTrain:
+    a = cast_binary_sparse_spike_train(array)
+
+    ratio = target_resolution / original_resolution
+    if ratio == 1:
+        return a
+
+    new_length = int(a.shape[0] * ratio)
+    new_spike_train = np.zeros(new_length, dtype=np.int8)
+
+    # upsampling: each source bin maps onto `ratio` target bins
+    if ratio > 1:
+        for idx, val in enumerate(a):
+            start = int(idx * ratio)
+            end = int((idx + 1) * ratio)
+            new_spike_train[start:end] = val
+
+    # downsampling: a target bin is 1 if any source bin in its window spiked
+    # (assumes an integer downsampling factor)
+    elif ratio < 1:
+        step = int(1 / ratio)
+        for idx in range(0, len(a), step):
+            if np.any(a[idx : idx + step]):
+                new_spike_train[idx // step] = 1
+
+    return new_spike_train
diff --git a/src/miv_simulator/config.py b/src/miv_simulator/config.py
index 3b9f062..e327868 100644
--- a/src/miv_simulator/config.py
+++ b/src/miv_simulator/config.py
@@ -3,13 +3,14 @@
     BaseModel as _BaseModel,
     Field,
     conlist,
+    AfterValidator,
+    BeforeValidator,
 )
 from typing import Literal, Dict, Any,
List, Tuple, Optional, Union, Callable
 from enum import IntEnum
 from collections import defaultdict
 import numpy as np
 from typing_extensions import Annotated
-from pydantic.functional_validators import AfterValidator, BeforeValidator
 
 
 # Definitions
diff --git a/src/miv_simulator/interface/neuroh5_graph.py b/src/miv_simulator/interface/neuroh5_graph.py
index c7a64e0..cdfe83c 100644
--- a/src/miv_simulator/interface/neuroh5_graph.py
+++ b/src/miv_simulator/interface/neuroh5_graph.py
@@ -1,7 +1,7 @@
 from machinable import Component
-from miv_simulator import simulator
 from neuroh5.io import read_population_names
 from typing import Dict
+from miv_simulator.utils.io import H5FileManager
 
 
 class NeuroH5Graph(Component):
@@ -12,7 +12,7 @@ def __init__(self, *args, **kwargs):
     @property
     def graph(self) -> None:
         if self._graph is None:
-            self._graph = simulator.nh5.Graph(self.local_directory())
+            self._graph = H5FileManager(self.local_directory())
         return self._graph
 
     def __call__(self) -> None:
diff --git a/src/miv_simulator/simulator/__init__.py b/src/miv_simulator/simulator/__init__.py
index a206e03..0d6cea2 100644
--- a/src/miv_simulator/simulator/__init__.py
+++ b/src/miv_simulator/simulator/__init__.py
@@ -1,14 +1,14 @@
 __doc__ = """Contains the end-user public API of the MiV-Simulator"""
 
-from miv_simulator.utils.io import create_neural_h5
+from miv_simulator.simulator.distribute_synapses import distribute_synapses
+from miv_simulator.simulator.generate_connections import generate_connections
 from miv_simulator.simulator.generate_network_architecture import (
     generate_network_architecture,
 )
-from miv_simulator.simulator.measure_distances import measure_distances
 from miv_simulator.simulator.generate_synapse_forest import (
     generate_synapse_forest,
 )
-from miv_simulator.simulator.distribute_synapses import distribute_synapses
-from miv_simulator.simulator.generate_connections import generate_connections
-from miv_simulator.simulator import nh5
+from miv_simulator.simulator.measure_distances import measure_distances
+from miv_simulator.simulator.execution_environment import ExecutionEnvironment
+from miv_simulator.utils.io import create_neural_h5
 from miv_simulator.utils.neuron import configure_hoc
diff --git a/src/miv_simulator/simulator/execution_environment.py b/src/miv_simulator/simulator/execution_environment.py
new file mode 100644
index 0000000..dbf7e15
--- /dev/null
+++ b/src/miv_simulator/simulator/execution_environment.py
@@ -0,0 +1,356 @@
+import logging
+import os
+import random
+import time
+from collections import defaultdict
+from typing import Optional
+
+from mpi4py import MPI
+from neuron import h
+
+from miv_simulator import config
+from miv_simulator.network import make_cells, connect_gjs, connect_cells
+from miv_simulator.synapses import SynapseAttributes
+from miv_simulator.utils import AbstractEnv, ExprClosure, read_from_yaml
+
+from neuroh5.io import (
+    read_cell_attribute_info,
+    read_population_names,
+    read_population_ranges,
+    read_projection_names,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class ExecutionEnvironment(AbstractEnv):
+    """Manages the runtime state within the rank"""
+
+    def __init__(
+        self,
+        comm: Optional[MPI.Intracomm] = None,
+        seed: Optional[int] = None,
+    ):
+        self.seed = random.Random(seed).randint(1, 2**16 - 1)
+
+        if comm is None:
+            comm = MPI.COMM_WORLD
+        self.comm = comm
+
+        # Optional prefix for resolving relative mechanism file paths
+        self.config_prefix: Optional[str] = None
+
+        # --- Resources
+
+        self.gidset = set()
+        self.node_allocation = None  # node rank map
+
+        # --- Statistics
+
+        self.mkcellstime = -0.0
+        self.connectgjstime = -0.0
+        self.connectcellstime = -0.0
+
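+        # The attribute groups below mirror the layout of miv_simulator's
+        # Env class so that network.make_cells/connect_gjs/connect_cells
+        # can drive this environment through the `_binding` shim set up in
+        # load_cells() and load_connections().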
+ # --- Graph + + self.cells = defaultdict(lambda: dict()) + self.artificial_cells = defaultdict(lambda: dict()) + self.biophys_cells = defaultdict(lambda: dict()) + self.spike_onset_delay = {} + self.recording_sets = {} + self.synapse_attributes = None + self.edge_count = defaultdict(dict) + self.syns_set = defaultdict(set) + + # --- State + self.cells_meta_data = None + self.connections_meta_data = None + + # --- Compat + + self.template_dict = {} + + # --- Simulator + + self.pc = h.pc + self.rank = int(self.pc.id()) + + # Spike time of all cells on this host + self.t_vec = h.Vector() + # Ids of spike times on this host + self.id_vec = h.Vector() + # Timestamps of intracellular traces on this host + self.t_rec = h.Vector() + + # --- miv_simulator.network.init equivalent + + def load_cells( + self, + filepath: str, + templates: str, + cell_types: config.CellTypes, + io_size: int = 0, + ): + if self.rank == 0: + logger.info("*** Creating cells...") + st = time.time() + + rank = self.comm.Get_rank() + if rank == 0: + color = 1 + else: + color = 0 + ## comm0 includes only rank 0 + comm0 = self.comm.Split(color, 0) + + cell_attribute_info = None + population_ranges = None + population_names = None + if rank == 0: + population_names = read_population_names(filepath, comm0) + (population_ranges, _) = read_population_ranges(filepath, comm0) + cell_attribute_info = read_cell_attribute_info( + filepath, population_names, comm=comm0 + ) + logger.info(f"population_names = {str(population_names)}") + logger.info(f"population_ranges = {str(population_ranges)}") + logger.info(f"attribute info: {str(cell_attribute_info)}") + population_ranges = self.comm.bcast(population_ranges, root=0) + population_names = self.comm.bcast(population_names, root=0) + cell_attribute_info = self.comm.bcast(cell_attribute_info, root=0) + + # TODO: refactor from declarative to imperative + celltypes = dict(cell_types) + typenames = sorted(celltypes.keys()) + for k in typenames: + population_range = population_ranges.get(k, None) + if population_range is not None: + celltypes[k]["start"] = population_ranges[k][0] + celltypes[k]["num"] = population_ranges[k][1] + if "mechanism file" in celltypes[k]: + if isinstance(celltypes[k]["mechanism file"], str): + celltypes[k]["mech_file_path"] = celltypes[k][ + "mechanism file" + ] + mech_dict = None + if rank == 0: + mech_file_path = celltypes[k]["mech_file_path"] + if self.config_prefix is not None: + mech_file_path = os.path.join( + self.config_prefix, mech_file_path + ) + mech_dict = read_from_yaml(mech_file_path) + else: + mech_dict = celltypes[k]["mechanism file"] + mech_dict = self.comm.bcast(mech_dict, root=0) + celltypes[k]["mech_dict"] = mech_dict + if "synapses" in celltypes[k]: + synapses_dict = celltypes[k]["synapses"] + if "weights" in synapses_dict: + weights_config = synapses_dict["weights"] + if isinstance(weights_config, list): + weights_dicts = weights_config + else: + weights_dicts = [weights_config] + for weights_dict in weights_dicts: + if "expr" in weights_dict: + expr = weights_dict["expr"] + parameter = weights_dict["parameter"] + const = weights_dict.get("const", {}) + clos = ExprClosure(parameter, expr, const) + weights_dict["closure"] = clos + synapses_dict["weights"] = weights_dicts + + self.cells_meta_data = { + "source": filepath, + "cell_attribute_info": cell_attribute_info, + "population_ranges": population_ranges, + "population_names": population_names, + "celltypes": celltypes, + } + + comm0.Free() + + class _binding: + pass + + this = _binding() + 
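+        # Populate the shim with the attribute interface that make_cells()
+        # and connect_gjs() expect from a miv_simulator Env instance.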
+        this.__dict__.update(
+            {
+                # bound
+                "pc": self.pc,
+                "data_file_path": filepath,
+                "io_size": io_size,
+                "comm": self.comm,
+                "node_allocation": self.node_allocation,
+                "cells": self.cells,
+                "artificial_cells": self.artificial_cells,
+                "biophys_cells": self.biophys_cells,
+                "spike_onset_delay": self.spike_onset_delay,
+                "recording_sets": self.recording_sets,
+                "t_vec": self.t_vec,
+                "id_vec": self.id_vec,
+                "t_rec": self.t_rec,
+                # compat
+                "gapjunctions_file_path": None,  # TODO
+                "gapjunctions": None,  # TODO
+                "recording_profile": None,  # TODO
+                "dt": 0.025,  # TODO: understand the implications of this
+                "datasetName": "",
+                "gidset": self.gidset,
+                "SWC_Types": config.SWCTypesDef.__members__,
+                "template_paths": [templates],
+                "dataset_path": None,
+                "dataset_prefix": "",
+                "template_dict": self.template_dict,
+                "cell_attribute_info": cell_attribute_info,
+                # processed copy (includes start/num/mech_dict entries)
+                "celltypes": celltypes,
+                "model_config": {
+                    "Random Seeds": {
+                        "Intracellular Recording Sample": self.seed
+                    }
+                },
+            }
+        )
+
+        make_cells(this)
+
+        # HACK(frthjf): given its initial `None` primitive data type, the
+        # env.node_allocation copy at the end of make_cells will
+        # be lost when the local function stack is freed;
+        # fortunately, gidset is heap-allocated so we can
+        # simply repeat the set operation here
+        self.node_allocation = set()
+        for gid in self.gidset:
+            self.node_allocation.add(gid)
+
+        self.mkcellstime = time.time() - st
+        if self.rank == 0:
+            logger.info(f"*** Cells created in {self.mkcellstime:.02f} s")
+        local_num_cells = sum(len(cells) for cells in self.cells.values())
+
+        logger.info(f"*** Rank {self.rank} created {local_num_cells} cells")
+
+        st = time.time()
+
+        connect_gjs(this)
+
+        self.pc.setup_transfer()
+        self.connectgjstime = time.time() - st
+        if rank == 0:
+            logger.info(
+                f"*** Gap junctions created in {self.connectgjstime:.02f} s"
+            )
+
+        # -- user-space OptoStim and LFP etc.
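+
+        # Typical call sequence (illustrative; file paths and the
+        # `cell_types`/`synapses` config objects are assumptions):
+        #
+        #     env = ExecutionEnvironment(seed=42)
+        #     env.load_cells(
+        #         filepath="cells.h5",
+        #         templates="templates",
+        #         cell_types=cell_types,  # a config.CellTypes mapping
+        #     )
+        #     env.load_connections(
+        #         filepath="connections.h5",
+        #         cell_filepath="cells.h5",
+        #         synapses=synapses,  # a config.Synapses mapping
+        #     )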
+ + def load_connections( + self, + filepath: str, + cell_filepath: str, + synapses: config.Synapses, + io_size: int = 0, + ): + if not self.cells_meta_data: + raise RuntimeError("Please load the cells first using load_cells()") + + st = time.time() + if self.rank == 0: + logger.info(f"*** Creating connections:") + + rank = self.comm.Get_rank() + if rank == 0: + color = 1 + else: + color = 0 + ## comm0 includes only rank 0 + comm0 = self.comm.Split(color, 0) + + projection_dict = None + if rank == 0: + projection_dict = defaultdict(list) + for src, dst in read_projection_names(filepath, comm=comm0): + projection_dict[dst].append(src) + projection_dict = dict(projection_dict) + logger.info(f"projection_dict = {str(projection_dict)}") + projection_dict = self.comm.bcast(projection_dict, root=0) + comm0.Free() + + class _binding: + pass + + this = _binding() + this.__dict__.update( + { + "pc": self.pc, + "connectivity_file_path": filepath, + "forest_file_path": cell_filepath, + "io_size": io_size, + "comm": self.comm, + "node_allocation": self.node_allocation, + "edge_count": self.edge_count, + "biophys_cells": self.biophys_cells, + "gidset": self.gidset, + "recording_sets": self.recording_sets, + "microcircuit_inputs": False, + "use_cell_attr_gen": False, # TODO + "cleanup": True, + "projection_dict": projection_dict, + "Populations": config.PopulationsDef.__members__, + "connection_config": synapses, + "connection_velocity": { # TODO config + "PYR": 250, + "STIM": 250, + "PVBC": 250, + "OLM": 250, + }, + "SWC_Types": config.SWCTypesDef.__members__, + "celltypes": self.cells_meta_data["celltypes"], + } + ) + self.synapse_attributes = SynapseAttributes( + this, + # TODO: expose config + { + "AMPA": "LinExp2Syn", + "NMDA": "LinExp2SynNMDA", + "GABA_A": "LinExp2Syn", + "GABA_B": "LinExp2Syn", + }, + { + "Exp2Syn": { + "mech_file": "exp2syn.mod", + "mech_params": ["tau1", "tau2", "e"], + "netcon_params": {"weight": 0}, + "netcon_state": {}, + }, + "LinExp2Syn": { + "mech_file": "lin_exp2syn.mod", + "mech_params": ["tau_rise", "tau_decay", "e"], + "netcon_params": {"weight": 0, "g_unit": 1}, + "netcon_state": {}, + }, + "LinExp2SynNMDA": { + "mech_file": "lin_exp2synNMDA.mod", + "mech_params": [ + "tau_rise", + "tau_decay", + "e", + "mg", + "Kd", + "gamma", + "vshift", + ], + "netcon_params": {"weight": 0, "g_unit": 1}, + "netcon_state": {}, + }, + }, + ) + this.__dict__["synapse_attributes"] = self.synapse_attributes + + connect_cells(this) + + self.pc.set_maxstep(10.0) + + self.connectcellstime = time.time() - st + + if self.rank == 0: + logger.info( + f"*** Done creating connections: time = {self.connectcellstime:.02f} s" + ) + edge_count = int(sum(self.edge_count[dest] for dest in self.edge_count)) + logger.info(f"*** Rank {rank} created {edge_count} connections") diff --git a/src/miv_simulator/simulator/nh5.py b/src/miv_simulator/simulator/nh5.py deleted file mode 100644 index e1d02c1..0000000 --- a/src/miv_simulator/simulator/nh5.py +++ /dev/null @@ -1,131 +0,0 @@ -from typing import Tuple - -import os -import pathlib - -import subprocess -import h5py - - -def _run(commands): - cmd = " ".join(commands) - print(cmd) - subprocess.check_output(commands) - - -def copy_dataset(f_src: h5py.File, f_dst: h5py.File, dset_path: str) -> None: - print(f"Copying {dset_path} from {f_src} to {f_dst} ...") - target_path = str(pathlib.Path(dset_path).parent) - f_src.copy(f_src[dset_path], f_dst[target_path]) - - -class Graph: - """Utility to manage NeuroH5 graph data""" - - def __init__(self, directory: 
str):
-        self.directory = directory
-
-    def local_directory(self, *append: str, create: bool = False) -> str:
-        d = os.path.join(os.path.abspath(self.directory), *append)
-        if create:
-            os.makedirs(d, exist_ok=True)
-        return d
-
-    @property
-    def cells_filepath(self) -> str:
-        return self.local_directory("cells.h5")
-
-    @property
-    def connections_filepath(self) -> str:
-        return self.local_directory("connections.h5")
-
-    def import_h5types(self, src: str):
-        with h5py.File(self.cells_filepath, "w") as f:
-            input_file = h5py.File(src, "r")
-            copy_dataset(input_file, f, "/H5Types")
-            input_file.close()
-
-        with h5py.File(self.connections_filepath, "w") as f:
-            input_file = h5py.File(src, "r")
-            copy_dataset(input_file, f, "/H5Types")
-            input_file.close()
-
-    def import_soma_coordinates(self, src: str, populations: Tuple[str] = ()):
-        with h5py.File(self.cells_filepath, "a") as f_dst:
-            grp = f_dst.create_group("Populations")
-
-            for p in populations:
-                grp.create_group(p)
-
-            for p in populations:
-                coords_dset_path = f"/Populations/{p}/Generated Coordinates"
-                coords_output_path = f"/Populations/{p}/Coordinates"
-                distances_dset_path = f"/Populations/{p}/Arc Distances"
-                with h5py.File(src, "r") as f_src:
-                    copy_dataset(f_src, f_dst, coords_dset_path)
-                    copy_dataset(f_src, f_dst, distances_dset_path)
-
-    def import_synapse_attributes(
-        self, population: str, forest_file: str, synapses_file: str
-    ):
-        forest_dset_path = f"/Populations/{population}/Trees"
-        forest_syns_dset_path = f"/Populations/{population}/Synapse Attributes"
-
-        cmd = [
-            "h5copy",
-            "-p",
-            "-s",
-            forest_dset_path,
-            "-d",
-            forest_dset_path,
-            "-i",
-            forest_file,
-            "-o",
-            self.cells_filepath,
-        ]
-        _run(cmd)
-
-        cmd = [
-            "h5copy",
-            "-p",
-            "-s",
-            forest_syns_dset_path,
-            "-d",
-            forest_syns_dset_path,
-            "-i",
-            synapses_file,
-            "-o",
-            self.cells_filepath,
-        ]
-        _run(cmd)
-
-    def import_projections(self, population: str, src: str):
-        projection_dset_path = f"/Projections/{population}"
-        cmd = [
-            "h5copy",
-            "-p",
-            "-s",
-            projection_dset_path,
-            "-d",
-            projection_dset_path,
-            "-i",
-            src,
-            "-o",
-            self.connections_filepath,
-        ]
-        _run(cmd)
-
-    def copy_stim_coordinates(self):
-        cmd = [
-            "h5copy",
-            "-p",
-            "-s",
-            "/Populations/STIM/Generated Coordinates",
-            "-d",
-            "/Populations/STIM/Coordinates",
-            "-i",
-            self.cells_filepath,
-            "-o",
-            self.cells_filepath,
-        ]
-        _run(cmd)
diff --git a/src/miv_simulator/synapses.py b/src/miv_simulator/synapses.py
index 63c63bd..1121f2e 100644
--- a/src/miv_simulator/synapses.py
+++ b/src/miv_simulator/synapses.py
@@ -1695,7 +1695,11 @@ def insert_hoc_cell_syns(
         if "default" in syn_params:
             mech_params = syn_params["default"]
         else:
-            mech_params = syn_params[swc_type]
+            try:
+                mech_params = syn_params[swc_type]
+            except KeyError:
+                # no per-SWC-type entry; fall back to the full mapping
+                mech_params = syn_params
 
         for syn_name, params in mech_params.items():
             syn_mech = make_syn_mech(
@@ -2106,7 +2110,8 @@ def config_syn(
                 if not failed:
                     setattr(syn, param, val(*param_vals))
                 else:
-                    setattr(syn, param, val)
+                    if val is not None:
+                        setattr(syn, param, val)
                 mech_param = True
                 failed = False
 
diff --git a/src/miv_simulator/utils/io.py b/src/miv_simulator/utils/io.py
index 01b97df..32cae97 100644
--- a/src/miv_simulator/utils/io.py
+++ b/src/miv_simulator/utils/io.py
@@ -1,7 +1,9 @@
-from typing import Any, List, Optional, Union, Dict
+from typing import Any, List, Optional, Union, Dict, Tuple
 
 import gc
+import pathlib
 import os
+import subprocess
 from collections import defaultdict
 from miv_simulator import config
 import h5py
@@ -1129,3
+1131,127 @@ def query_cell_attributes(input_file, population_names, namespace_ids=None):
     else:
         namespace_id_lst = namespace_ids
     return namespace_id_lst, attr_info_dict
+
+
+def _run(commands):
+    cmd = " ".join(commands)
+    print(cmd)
+    subprocess.check_output(commands)
+
+
+def copy_dataset(f_src: h5py.File, f_dst: h5py.File, dset_path: str) -> None:
+    print(f"Copying {dset_path} from {f_src} to {f_dst} ...")
+    target_path = str(pathlib.Path(dset_path).parent)
+    f_src.copy(f_src[dset_path], f_dst[target_path])
+
+
+class H5FileManager:
+    """Utility to manage NeuroH5 simulator files"""
+
+    def __init__(self, directory: str):
+        self.directory = directory
+
+    def local_directory(self, *append: str, create: bool = False) -> str:
+        d = os.path.join(os.path.abspath(self.directory), *append)
+        if create:
+            os.makedirs(d, exist_ok=True)
+        return d
+
+    @property
+    def cells_filepath(self) -> str:
+        return self.local_directory("cells.h5")
+
+    @property
+    def connections_filepath(self) -> str:
+        return self.local_directory("connections.h5")
+
+    def import_h5types(self, src: str):
+        with h5py.File(self.cells_filepath, "w") as f:
+            with h5py.File(src, "r") as input_file:
+                copy_dataset(input_file, f, "/H5Types")
+
+        with h5py.File(self.connections_filepath, "w") as f:
+            with h5py.File(src, "r") as input_file:
+                copy_dataset(input_file, f, "/H5Types")
+
+    def import_soma_coordinates(
+        self, src: str, populations: Tuple[str, ...] = ()
+    ):
+        with h5py.File(self.cells_filepath, "a") as f_dst:
+            grp = f_dst.create_group("Populations")
+
+            for p in populations:
+                grp.create_group(p)
+
+            for p in populations:
+                coords_dset_path = f"/Populations/{p}/Generated Coordinates"
+                distances_dset_path = f"/Populations/{p}/Arc Distances"
+                with h5py.File(src, "r") as f_src:
+                    copy_dataset(f_src, f_dst, coords_dset_path)
+                    copy_dataset(f_src, f_dst, distances_dset_path)
+
+    def import_synapse_attributes(
+        self, population: str, forest_file: str, synapses_file: str
+    ):
+        forest_dset_path = f"/Populations/{population}/Trees"
+        forest_syns_dset_path = f"/Populations/{population}/Synapse Attributes"
+
+        cmd = [
+            "h5copy",
+            "-p",
+            "-s",
+            forest_dset_path,
+            "-d",
+            forest_dset_path,
+            "-i",
+            forest_file,
+            "-o",
+            self.cells_filepath,
+        ]
+        _run(cmd)
+
+        cmd = [
+            "h5copy",
+            "-p",
+            "-s",
+            forest_syns_dset_path,
+            "-d",
+            forest_syns_dset_path,
+            "-i",
+            synapses_file,
+            "-o",
+            self.cells_filepath,
+        ]
+        _run(cmd)
+
+    def import_projections(self, population: str, src: str):
+        projection_dset_path = f"/Projections/{population}"
+        cmd = [
+            "h5copy",
+            "-p",
+            "-s",
+            projection_dset_path,
+            "-d",
+            projection_dset_path,
+            "-i",
+            src,
+            "-o",
+            self.connections_filepath,
+        ]
+        _run(cmd)
+
+    def copy_stim_coordinates(self):
+        cmd = [
+            "h5copy",
+            "-p",
+            "-s",
+            "/Populations/STIM/Generated Coordinates",
+            "-d",
+            "/Populations/STIM/Coordinates",
+            "-i",
+            self.cells_filepath,
+            "-o",
+            self.cells_filepath,
+        ]
+        _run(cmd)
diff --git a/tests/test_coding.py b/tests/test_coding.py
new file mode 100644
index 0000000..a08ef62
--- /dev/null
+++ b/tests/test_coding.py
@@ -0,0 +1,47 @@
+from miv_simulator import coding as t
+import numpy as np
+
+
+def test_coding_spike_times_vs_binary_sparse_spike_train():
+    for a, b in [
+        ([0.1, 0.3, 0.4, 0.85], [1, 1]),
+        ([0.8], [0, 1]),
+    ]:
+        result = t.spike_times_2_binary_sparse_spike_train(a, 0.5)
+        expected = np.array(b, dtype=np.int8)
+        assert np.array_equal(result, expected)
+
+    for a, b in [
+        ([1, 0,
1], [0.0, 1.0]), + ([0, 1], [0.5]), + ]: + spike_train = np.array(a, dtype=np.int8) + result = t.binary_sparse_spike_train_2_spike_times(spike_train, 0.5) + expected = np.array(b) + assert np.array_equal(result, expected) + + +def test_coding_adjust_temporal_resolution(): + spike_train = np.array([0, 1, 0, 1, 0], dtype=np.int8) + + # identity + adjusted = t.adjust_temporal_resolution(spike_train, 1, 1) + assert np.array_equal(adjusted, spike_train) + + # up + adjusted = t.adjust_temporal_resolution(spike_train, 0.5, 1) + expected = np.array([0, 0, 1, 1, 0, 0, 1, 1, 0, 0], dtype=np.int8) + assert np.array_equal(adjusted, expected) + + # down + adjusted = t.adjust_temporal_resolution(spike_train, 2, 1) + expected = np.array([1, 1], dtype=np.int8) + assert np.array_equal(adjusted, expected) + + +def test_coding_typing_cast(): + assert t.cast_spike_times(0.5).shape == (1,) + assert t.cast_spike_times([0.5, 0.1])[1] == 0.5 + assert t.cast_spike_times(int(1))[0] == float(1.0) + + assert t.cast_binary_sparse_spike_train(0.1)[0] == 0
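+
+
+def test_coding_round_trip():
+    # Illustrative round-trip sketch: encoding spike times into a binary
+    # train and decoding again recovers the original values, assuming they
+    # lie exactly on bin boundaries (an assumption of this example).
+    times = np.array([0.0, 0.5, 1.0])
+    train = t.spike_times_2_binary_sparse_spike_train(times, 0.5)
+    assert np.array_equal(train, np.array([1, 1, 1], dtype=np.int8))
+    recovered = t.binary_sparse_spike_train_2_spike_times(train, 0.5)
+    assert np.array_equal(recovered, times)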