Skip to content

Commit

Permalink
Speed up ProcessWithContext.process_*() (#104)
Browse files Browse the repository at this point in the history
* avoid use of pd.concat()

* TST: test for bad index

* TST: test for bad index
  • Loading branch information
frankenjoe authored Mar 13, 2023
1 parent f4687e1 commit c74dfff
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 46 deletions.
103 changes: 57 additions & 46 deletions audinterface/core/process_with_context.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import collections
import inspect
import itertools
import typing
import warnings

Expand Down Expand Up @@ -190,38 +191,46 @@ def process_index(
.. _audformat: https://audeering.github.io/audformat/data-format.html
"""
utils.assert_index(index)

index = audformat.utils.to_segmented_index(index)

if index.empty:
return pd.Series(index=index, dtype=object)
if len(index) == 0:
return pd.Series([], index=index, dtype=object)

files = index.levels[0]
ys = [None] * len(files)
ys = []

with audeer.progress_bar(
files,
total=len(files),
disable=not self.verbose,
) as pbar:
for idx, file in enumerate(pbar):
desc = audeer.format_display_message(file, pbar=True)
pbar.set_description(desc, refresh=True)

if self.verbose: # pragma: no cover
desc = audeer.format_display_message(file, pbar=True)
pbar.set_description(desc, refresh=True)

mask = index.isin([file], 0)
select = index[mask].droplevel(0)

signal, sampling_rate = utils.read_audio(file, root=root)
ys[idx] = pd.Series(
self._process_signal_from_index(
signal,
sampling_rate,
select,
idx=idx,
root=root,
file=file,
).values,
index=index[mask],
y = self._process_signal_from_index(
signal,
sampling_rate,
select,
idx=idx,
root=root,
file=file,
)

return pd.concat(ys)
ys.append(y)

y = list(itertools.chain.from_iterable([x for x in ys]))
y = pd.Series(y, index)

return y

def _process_signal_from_index(
self,
Expand All @@ -232,36 +241,30 @@ def _process_signal_from_index(
idx: int = 0,
root: str = None,
file: str = None,
) -> pd.Series:

utils.assert_index(index)
) -> typing.Any:

if len(index) == 0:
y = pd.Series([], index=index, dtype=object)
else:
starts_i, ends_i = utils.segments_to_indices(
signal,
sampling_rate,
index,
)
y = self._call(
signal,
sampling_rate,
starts_i,
ends_i,
idx=idx,
root=root,
file=file,
starts_i, ends_i = utils.segments_to_indices(
signal,
sampling_rate,
index,
)
y = self._call(
signal,
sampling_rate,
starts_i,
ends_i,
idx=idx,
root=root,
file=file,
)
if (
not isinstance(y, collections.abc.Iterable)
or len(y) != len(index)
):
raise RuntimeError(
'process_func has to return a sequence of results, '
f'matching the length {len(index)} of the index. '
)
if (
not isinstance(y, collections.abc.Iterable)
or len(y) != len(index)
):
raise RuntimeError(
'process_func has to return a sequence of results, '
f'matching the length {len(index)} of the index. '
)
y = pd.Series(y, index=index)

return y

Expand Down Expand Up @@ -294,11 +297,19 @@ def process_signal_from_index(
.. _audformat: https://audeering.github.io/audformat/data-format.html
"""
return self._process_signal_from_index(
utils.assert_index(index)

if len(index) == 0:
return pd.Series([], index=index, dtype=object)

y = self._process_signal_from_index(
signal,
sampling_rate,
index,
)
y = pd.Series(y, index)

return y

def _call(
self,
Expand Down Expand Up @@ -350,7 +361,7 @@ def __call__(
r"""Apply processing to signal.
This function processes the signal **without** transforming the output
into a :class:`pd.Series`. Instead it will return the raw processed
into a :class:`pd.Series`. Instead, it will return the raw processed
signal. However, if channel selection, mixdown and/or resampling
is enabled, the signal will be first remixed and resampled if the
input sampling rate does not fit the expected sampling rate.
Expand Down
12 changes: 12 additions & 0 deletions tests/test_process_with_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,10 @@ def test_process_index(tmpdir):
)
np.testing.assert_equal(np.atleast_2d(x[channel]), value)

# bad index
with pytest.raises(ValueError):
process.process_index(pd.Index([]), root=root)


@pytest.mark.parametrize(
'process_func,process_func_with_context,signal,sampling_rate,index',
Expand Down Expand Up @@ -242,6 +246,14 @@ def test_process_index(tmpdir):
),
marks=pytest.mark.xfail(raises=RuntimeError),
),
pytest.param( # not a valid index
lambda signal, sampling_rate: [0, 1],
lambda signal, sampling_rate, starts, ends: [0, 1],
np.random.random(5 * 44100),
44100,
pd.Index([]),
marks=pytest.mark.xfail(raises=ValueError),
),
],
)
def test_process_signal_from_index(
Expand Down

0 comments on commit c74dfff

Please sign in to comment.