Skip to content

Commit

Permalink
Imperative execution environment (#69)
Browse files Browse the repository at this point in the history
  • Loading branch information
frthjf authored Oct 18, 2023
1 parent 1b3be46 commit 33cc698
Show file tree
Hide file tree
Showing 9 changed files with 658 additions and 142 deletions.
112 changes: 112 additions & 0 deletions src/miv_simulator/coding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
import numpy as np
from numpy.typing import NDArray
from typing import Annotated as EventArray, Dict

SpikeTimesLike = EventArray[NDArray[np.float_], "SpikeTimesLike ..."]
"""Potentially unsorted or scalar data that can be transformed into `SpikeTimes`"""

SpikeTimes = EventArray[NDArray[np.float64], "SpikeTimes T ..."]
"""Sorted array of absolute spike times"""


# Spike train encodings (RLE, delta encoding, variable time binning etc.)

BinarySparseSpikeTrainLike = EventArray[
NDArray, "BinarySparseSpikeTrainLike ..."
]
"""Binary data that can be cast to the `BinarySparseSpikeTrain` format"""


BinarySparseSpikeTrain = EventArray[
NDArray[np.int8], "BinarySparseSpikeTrain t_bin ..."
]
"""Binary spike train representation for a given temporal resolution"""


def _inspect(type_) -> Dict:
annotation = type_.__metadata__[0]
name, *dims = annotation.split(" ")

return {
"annotation": annotation,
"name": name,
"dims": dims,
"dtype": type_.__origin__.__args__[1].__args__[0],
}


def _cast(a, a_type, r_type): # -> r_type
a_t, r_t = _inspect(a_type), _inspect(r_type)
if a_t["name"].replace("Like", "") != r_t["name"]:
raise ValueError(
f"Expected miv_simulator.typing.{r_t['name']}Like but found {a_t['name']}"
)
v = np.array(a, dtype=r_t["dtype"])
if len(v.shape) == 0:
return np.reshape(
v,
[
1,
],
)
return v


def cast_spike_times(a: SpikeTimesLike) -> SpikeTimes:
return np.sort(_cast(a, SpikeTimesLike, SpikeTimes), axis=0)


def cast_binary_sparse_spike_train(
a: BinarySparseSpikeTrainLike,
) -> BinarySparseSpikeTrain:
return _cast(a, BinarySparseSpikeTrainLike, BinarySparseSpikeTrain)


def spike_times_2_binary_sparse_spike_train(
array: SpikeTimesLike, temporal_resolution: float
) -> BinarySparseSpikeTrain:
a = cast_spike_times(array)
bins = np.floor(a / temporal_resolution).astype(int)
# since a is sorted, maximum is last value
spike_train = np.zeros(bins[-1] + 1, dtype=np.int8)
spike_train[bins] = 1
return spike_train


def binary_sparse_spike_train_2_spike_times(
array: BinarySparseSpikeTrainLike, temporal_resolution: float
) -> SpikeTimes:
a = cast_binary_sparse_spike_train(array)
spike_indices = np.where(a == 1)[0]
spike_times = spike_indices * temporal_resolution
return spike_times


def adjust_temporal_resolution(
array: BinarySparseSpikeTrainLike,
original_resolution: float,
target_resolution: float,
) -> BinarySparseSpikeTrain:
a = cast_binary_sparse_spike_train(array)

ratio = target_resolution / original_resolution
if ratio == 1:
return a

new_length = int(a.shape[0] * ratio)
new_spike_train = np.zeros(new_length, dtype=np.int8)

# up
if ratio > 1:
for idx, val in enumerate(a):
start = int(idx * ratio)
end = int((idx + 1) * ratio)
new_spike_train[start:end] = val

# down
elif ratio < 1:
for idx in range(0, len(a), int(1 / ratio)):
if np.any(a[idx : idx + int(1 / ratio)]):
new_spike_train[idx // int(1 / ratio)] = 1

return new_spike_train
3 changes: 2 additions & 1 deletion src/miv_simulator/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@
BaseModel as _BaseModel,
Field,
conlist,
AfterValidator,
BeforeValidator,
)
from typing import Literal, Dict, Any, List, Tuple, Optional, Union, Callable
from enum import IntEnum
from collections import defaultdict
import numpy as np
from typing_extensions import Annotated
from pydantic.functional_validators import AfterValidator, BeforeValidator

# Definitions

Expand Down
4 changes: 2 additions & 2 deletions src/miv_simulator/interface/neuroh5_graph.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from machinable import Component
from miv_simulator import simulator
from neuroh5.io import read_population_names
from typing import Dict
from miv_simulator.utils.io import H5FileManager


class NeuroH5Graph(Component):
Expand All @@ -12,7 +12,7 @@ def __init__(self, *args, **kwargs):
@property
def graph(self) -> None:
if self._graph is None:
self._graph = simulator.nh5.Graph(self.local_directory())
self._graph = H5FileManager(self.local_directory())
return self._graph

def __call__(self) -> None:
Expand Down
10 changes: 5 additions & 5 deletions src/miv_simulator/simulator/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
__doc__ = """Contains the end-user public API of the MiV-Simulator"""

from miv_simulator.utils.io import create_neural_h5
from miv_simulator.simulator.distribute_synapses import distribute_synapses
from miv_simulator.simulator.generate_connections import generate_connections
from miv_simulator.simulator.generate_network_architecture import (
generate_network_architecture,
)
from miv_simulator.simulator.measure_distances import measure_distances
from miv_simulator.simulator.generate_synapse_forest import (
generate_synapse_forest,
)
from miv_simulator.simulator.distribute_synapses import distribute_synapses
from miv_simulator.simulator.generate_connections import generate_connections
from miv_simulator.simulator import nh5
from miv_simulator.simulator.measure_distances import measure_distances
from miv_simulator.simulator.execution_environment import ExecutionEnvironment
from miv_simulator.utils.io import create_neural_h5
from miv_simulator.utils.neuron import configure_hoc
Loading

0 comments on commit 33cc698

Please sign in to comment.