Skip to content

Commit

Permalink
Ensure transform.Function returns writable arrays (#35)
Browse files Browse the repository at this point in the history
* Add failing test

* Simplify failing test

* Add new test

* Fix linter

* Add fix to code

* Update tests/test_read_only_array.py

Co-authored-by: audeerington <[email protected]>

* Fix tests

---------

Co-authored-by: audeerington <[email protected]>
  • Loading branch information
hagenw and audeerington authored Jul 5, 2024
1 parent aa87c89 commit a665a49
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 0 deletions.
5 changes: 5 additions & 0 deletions auglib/core/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,11 @@ def __call__(
if signal.dtype != DTYPE:
signal = signal.astype(DTYPE)

# Ensure signal is not read-only
# (https://github.com/audeering/auglib/issues/31)
if not signal.flags["WRITEABLE"]:
signal = signal.copy()

if preserve_level:
signal_level = rms_db(signal)

Expand Down
91 changes: 91 additions & 0 deletions tests/test_read_only_array.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
import numpy as np
import pytest

import auglib


auglib.seed(0)


def identity(signal, sampling_rate):
return signal


def read_only(
signal: np.array,
sampling_rate: int,
):
signal.setflags(write=False)
return signal


@pytest.mark.parametrize("signal", [[1, 1]])
@pytest.mark.parametrize("sampling_rate", [8000])
@pytest.mark.parametrize(
"transform",
[
auglib.transform.AMRNB(4750),
auglib.transform.Append(np.ones((1, 1))),
auglib.transform.AppendValue(1, unit="samples"),
auglib.transform.BabbleNoise([np.ones((1, 2))]),
auglib.transform.BandPass(1000, 200),
auglib.transform.BandStop(1000, 200),
auglib.transform.Clip(),
auglib.transform.ClipByRatio(0.05),
auglib.transform.CompressDynamicRange(-15, 1 / 4),
auglib.transform.Fade(in_dur=0.2, out_dur=0.7),
auglib.transform.FFTConvolve(np.ones((1, 1))),
auglib.transform.Function(identity),
auglib.transform.GainStage(3),
auglib.transform.HighPass(3000),
auglib.transform.LowPass(100),
auglib.transform.Mask(auglib.transform.Clip()),
auglib.transform.Mix(np.ones((1, 1))),
auglib.transform.NormalizeByPeak(),
auglib.transform.PinkNoise(),
auglib.transform.Prepend(np.ones((1, 1))),
auglib.transform.PrependValue(1, unit="samples"),
auglib.transform.Resample(4000),
auglib.transform.Shift(1, unit="samples"),
auglib.transform.Tone(100),
auglib.transform.Trim(start_pos=0, end_pos=1, unit="samples"),
auglib.transform.WhiteNoiseGaussian(),
auglib.transform.WhiteNoiseUniform(),
],
)
def test_compose_read_only(
signal: np.array,
sampling_rate: int,
transform: auglib.transform.Base,
):
r"""Test applying transform on read-only array.
Certain custom transforms
(e.g. when using sox.Transformer)
can return numpy arrays in read-only mode.
If other transforms try to write to this array,
without making a copy first,
they will fail, see
https://github.com/audeering/auglib/issues/31
Args:
signal: signal
sampling_rate: sampling rate in Hz
transform: transform
"""
signal = np.array(signal, dtype=auglib.core.transform.DTYPE)

# Apply transform to read-only signal
signal.setflags(write=False)
augmented_signal = transform(signal, sampling_rate)
assert augmented_signal.flags["WRITEABLE"]

# Apply transform in compose
# after transform that makes signal read-only
compose_transform = auglib.transform.Compose(
[auglib.transform.Function(read_only), transform]
)
augmented_signal = compose_transform(signal, sampling_rate)
assert augmented_signal.flags["WRITEABLE"]

0 comments on commit a665a49

Please sign in to comment.