From c74dfffd0c34c07f14fc28572d45e8bb7b1d27e7 Mon Sep 17 00:00:00 2001 From: Johannes Wagner Date: Mon, 13 Mar 2023 12:37:56 +0100 Subject: [PATCH] Speed up ProcessWithContext.process_*() (#104) * avoid use of pd.concat() * TST: test for bad index * TST: test for bad index --- audinterface/core/process_with_context.py | 103 ++++++++++++---------- tests/test_process_with_context.py | 12 +++ 2 files changed, 69 insertions(+), 46 deletions(-) diff --git a/audinterface/core/process_with_context.py b/audinterface/core/process_with_context.py index fa89ea5..64d3a3b 100644 --- a/audinterface/core/process_with_context.py +++ b/audinterface/core/process_with_context.py @@ -1,5 +1,6 @@ import collections import inspect +import itertools import typing import warnings @@ -190,13 +191,15 @@ def process_index( .. _audformat: https://audeering.github.io/audformat/data-format.html """ + utils.assert_index(index) + index = audformat.utils.to_segmented_index(index) - if index.empty: - return pd.Series(index=index, dtype=object) + if len(index) == 0: + return pd.Series([], index=index, dtype=object) files = index.levels[0] - ys = [None] * len(files) + ys = [] with audeer.progress_bar( files, @@ -204,24 +207,30 @@ def process_index( disable=not self.verbose, ) as pbar: for idx, file in enumerate(pbar): - desc = audeer.format_display_message(file, pbar=True) - pbar.set_description(desc, refresh=True) + + if self.verbose: # pragma: no cover + desc = audeer.format_display_message(file, pbar=True) + pbar.set_description(desc, refresh=True) + mask = index.isin([file], 0) select = index[mask].droplevel(0) + signal, sampling_rate = utils.read_audio(file, root=root) - ys[idx] = pd.Series( - self._process_signal_from_index( - signal, - sampling_rate, - select, - idx=idx, - root=root, - file=file, - ).values, - index=index[mask], + y = self._process_signal_from_index( + signal, + sampling_rate, + select, + idx=idx, + root=root, + file=file, ) - return pd.concat(ys) + ys.append(y) + + y = list(itertools.chain.from_iterable([x for x in ys])) + y = pd.Series(y, index) + + return y def _process_signal_from_index( self, @@ -232,36 +241,30 @@ def _process_signal_from_index( idx: int = 0, root: str = None, file: str = None, - ) -> pd.Series: - - utils.assert_index(index) + ) -> typing.Any: - if len(index) == 0: - y = pd.Series([], index=index, dtype=object) - else: - starts_i, ends_i = utils.segments_to_indices( - signal, - sampling_rate, - index, - ) - y = self._call( - signal, - sampling_rate, - starts_i, - ends_i, - idx=idx, - root=root, - file=file, + starts_i, ends_i = utils.segments_to_indices( + signal, + sampling_rate, + index, + ) + y = self._call( + signal, + sampling_rate, + starts_i, + ends_i, + idx=idx, + root=root, + file=file, + ) + if ( + not isinstance(y, collections.abc.Iterable) + or len(y) != len(index) + ): + raise RuntimeError( + 'process_func has to return a sequence of results, ' + f'matching the length {len(index)} of the index. ' ) - if ( - not isinstance(y, collections.abc.Iterable) - or len(y) != len(index) - ): - raise RuntimeError( - 'process_func has to return a sequence of results, ' - f'matching the length {len(index)} of the index. ' - ) - y = pd.Series(y, index=index) return y @@ -294,11 +297,19 @@ def process_signal_from_index( .. _audformat: https://audeering.github.io/audformat/data-format.html """ - return self._process_signal_from_index( + utils.assert_index(index) + + if len(index) == 0: + return pd.Series([], index=index, dtype=object) + + y = self._process_signal_from_index( signal, sampling_rate, index, ) + y = pd.Series(y, index) + + return y def _call( self, @@ -350,7 +361,7 @@ def __call__( r"""Apply processing to signal. This function processes the signal **without** transforming the output - into a :class:`pd.Series`. Instead it will return the raw processed + into a :class:`pd.Series`. Instead, it will return the raw processed signal. However, if channel selection, mixdown and/or resampling is enabled, the signal will be first remixed and resampled if the input sampling rate does not fit the expected sampling rate. diff --git a/tests/test_process_with_context.py b/tests/test_process_with_context.py index 5512fe6..3d95b98 100644 --- a/tests/test_process_with_context.py +++ b/tests/test_process_with_context.py @@ -133,6 +133,10 @@ def test_process_index(tmpdir): ) np.testing.assert_equal(np.atleast_2d(x[channel]), value) + # bad index + with pytest.raises(ValueError): + process.process_index(pd.Index([]), root=root) + @pytest.mark.parametrize( 'process_func,process_func_with_context,signal,sampling_rate,index', @@ -242,6 +246,14 @@ def test_process_index(tmpdir): ), marks=pytest.mark.xfail(raises=RuntimeError), ), + pytest.param( # not a valid index + lambda signal, sampling_rate: [0, 1], + lambda signal, sampling_rate, starts, ends: [0, 1], + np.random.random(5 * 44100), + 44100, + pd.Index([]), + marks=pytest.mark.xfail(raises=ValueError), + ), ], ) def test_process_signal_from_index(