inference.py
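"""Run sound event detection with a pretrained transformer model on a single
audio file and print the detected events for each detection threshold."""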
import argparse
import librosa
import torch
from data_util import audioset_classes
from helpers.decode import batched_decode_preds
from helpers.encode import ManyHotEncoder
from models.atstframe.ATSTF_wrapper import ATSTWrapper
from models.beats.BEATs_wrapper import BEATsWrapper
from models.frame_passt.fpasst_wrapper import FPaSSTWrapper
from models.m2d.M2D_wrapper import M2DWrapper
from models.asit.ASIT_wrapper import ASiTWrapper
from models.prediction_wrapper import PredictionsWrapper


def sound_event_detection(args):
    """
    Run sound event detection on a single audio clip: load the selected
    pretrained model, split the audio into 10-second chunks, predict
    frame-level class probabilities, and decode them into labeled events.
    """
    device = torch.device('cuda') if args.cuda and torch.cuda.is_available() else torch.device('cpu')
    model_name = args.model_name

    if model_name == "BEATs":
        beats = BEATsWrapper()
        model = PredictionsWrapper(beats, checkpoint="BEATs_strong_1")
    elif model_name == "ATST-F":
        atst = ATSTWrapper()
        model = PredictionsWrapper(atst, checkpoint="ATST-F_strong_1")
    elif model_name == "fpasst":
        fpasst = FPaSSTWrapper()
        model = PredictionsWrapper(fpasst, checkpoint="fpasst_strong_1")
    elif model_name == "M2D":
        m2d = M2DWrapper()
        model = PredictionsWrapper(m2d, checkpoint="M2D_strong_1", embed_dim=m2d.m2d.cfg.feature_d)
    elif model_name == "ASIT":
        asit = ASiTWrapper()
        model = PredictionsWrapper(asit, checkpoint="ASIT_strong_1")
    else:
        raise NotImplementedError(f"Model {model_name} not (yet) implemented")

    model.eval()
    model.to(device)

    sample_rate = 16_000  # all our models are trained on 16 kHz audio
    segment_duration = 10  # all models are trained on 10-second clips
    segment_samples = segment_duration * sample_rate
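    # e.g. 10 s * 16,000 Hz = 160,000 samples per chunk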
    # load audio
    waveform, _ = librosa.load(args.audio_file, sr=sample_rate, mono=True)
    waveform = torch.from_numpy(waveform[None, :]).to(device)
    waveform_len = waveform.shape[1]
    audio_len = waveform_len / sample_rate  # in seconds
    print("Audio length (seconds): ", audio_len)

    # the encoder manages decoding of model predictions into dataframes
    # containing event labels, onsets and offsets
    encoder = ManyHotEncoder(audioset_classes.as_strong_train_classes, audio_len=audio_len)

    # split the audio file into 10-second chunks (the last one may be partial)
    num_chunks = waveform_len // segment_samples + (waveform_len % segment_samples != 0)
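    # e.g. a 25-second clip yields 25 // 10 + 1 = 3 chunks; the last 5-second
    # chunk is zero-padded to the full 10 seconds below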
    all_predictions = []
    # process each 10-second chunk
    for i in range(num_chunks):
        start_idx = i * segment_samples
        end_idx = min((i + 1) * segment_samples, waveform_len)
        waveform_chunk = waveform[:, start_idx:end_idx]

        # zero-pad the last chunk if it is shorter than 10 seconds
        if waveform_chunk.shape[1] < segment_samples:
            pad_size = segment_samples - waveform_chunk.shape[1]
            waveform_chunk = torch.nn.functional.pad(waveform_chunk, (0, pad_size))

        # run inference on the chunk
        with torch.no_grad():
            mel = model.mel_forward(waveform_chunk)
            y_strong, _ = model(mel)

        # collect the frame-level predictions
        all_predictions.append(y_strong)

    # concatenate all predictions along the time axis
    y_strong = torch.cat(all_predictions, dim=2)
    # convert logits into per-class probabilities
    y_strong = torch.sigmoid(y_strong)
    (
        scores_unprocessed,
        scores_postprocessed,
        decoded_predictions,
    ) = batched_decode_preds(
        y_strong.float(),
        [args.audio_file],
        encoder,
        median_filter=args.median_window,
        thresholds=args.detection_thresholds,
    )
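
    # decoded_predictions maps each detection threshold to a dataframe of
    # detected events (label/onset/offset columns, per the encoder above)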
    for th in decoded_predictions:
        print("***************************************")
        print(f"Detected events using threshold {th}:")
        print(decoded_predictions[th].sort_values(by="onset"))
        print("***************************************")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Run sound event detection on a single audio file.')
    # available models: BEATs, ASIT, ATST-F, fpasst, M2D
    parser.add_argument('--model_name', type=str, default='BEATs')
    parser.add_argument('--audio_file', type=str,
                        default='test_files/752547__iscence__milan_metro_coming_in_station.wav')
    parser.add_argument('--detection_thresholds', type=float, nargs='+', default=[0.1, 0.2, 0.5])
    parser.add_argument('--median_window', type=int, default=12)
    parser.add_argument('--cuda', action='store_true', default=False)
    args = parser.parse_args()

    assert args.model_name in ["BEATs", "ASIT", "ATST-F", "fpasst", "M2D"], \
        f"Unknown model name: {args.model_name}"
    sound_event_detection(args)
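
# Example invocations (sketches; 'my_clip.wav' is a placeholder path):
#   python inference.py --model_name ATST-F --audio_file my_clip.wav --cuda
#   python inference.py --model_name BEATs --detection_thresholds 0.2 0.5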