Example app VAD default + memory reduction #217

Merged · 11 commits · Oct 8, 2024

Examples/WhisperAX/WhisperAX/Views/ContentView.swift (8 additions, 3 deletions)

@@ -51,7 +51,8 @@ struct ContentView: View {
@AppStorage("silenceThreshold") private var silenceThreshold: Double = 0.3
@AppStorage("useVAD") private var useVAD: Bool = true
@AppStorage("tokenConfirmationsNeeded") private var tokenConfirmationsNeeded: Double = 2
@AppStorage("chunkingStrategy") private var chunkingStrategy: ChunkingStrategy = .none
@AppStorage("concurrentWorkerCount") private var concurrentWorkerCount: Int = 4
@AppStorage("chunkingStrategy") private var chunkingStrategy: ChunkingStrategy = .vad
@AppStorage("encoderComputeUnits") private var encoderComputeUnits: MLComputeUnits = .cpuAndNeuralEngine
@AppStorage("decoderComputeUnits") private var decoderComputeUnits: MLComputeUnits = .cpuAndNeuralEngine

@@ -1269,12 +1270,15 @@ struct ContentView: View {

     func transcribeCurrentFile(path: String) async throws {
         // Load and convert buffer in a limited scope
+        Logging.debug("Loading audio file: \(path)")
+        let loadingStart = Date()
         let audioFileSamples = try await Task {
             try autoreleasepool {
-                let audioFileBuffer = try AudioProcessor.loadAudio(fromPath: path)
-                return AudioProcessor.convertBufferToArray(buffer: audioFileBuffer)
+                return try AudioProcessor.loadAudioAsFloatArray(fromPath: path)
             }
         }.value
+        Logging.debug("Loaded audio file in \(Date().timeIntervalSince(loadingStart)) seconds")

         let transcription = try await transcribeAudioSamples(audioFileSamples)

@@ -1316,6 +1320,7 @@ struct ContentView: View {
                 withoutTimestamps: !enableTimestamps,
                 wordTimestamps: true,
                 clipTimestamps: seekClip,
+                concurrentWorkerCount: concurrentWorkerCount,
                 chunkingStrategy: chunkingStrategy
             )

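Note (illustrative, not part of the diff): the two new @AppStorage defaults above feed straight into DecodingOptions. A minimal sketch, assuming an already-initialized WhisperKit instance named `pipe` (a hypothetical name):

    // Mirrors the example app's new defaults in a plain API call.
    func transcribeWithDefaults(pipe: WhisperKit, path: String) async throws -> [TranscriptionResult] {
        let options = DecodingOptions(
            concurrentWorkerCount: 4, // matches @AppStorage("concurrentWorkerCount")
            chunkingStrategy: .vad    // matches the new @AppStorage default
        )
        return try await pipe.transcribe(audioPath: path, decodeOptions: options)
    }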

Sources/WhisperKit/Core/Audio/AudioChunker.swift (0 additions, 1 deletion)

@@ -81,7 +81,6 @@ open class VADAudioChunker: AudioChunking {
         // Typically this will be the full audio file, unless seek points are explicitly provided
         var startIndex = seekClipStart
         while startIndex < seekClipEnd - windowPadding {
-            let currentFrameLength = audioArray.count
             guard startIndex >= 0 && startIndex < audioArray.count else {
                 throw WhisperError.audioProcessingFailed("startIndex is outside the buffer size")
             }

Sources/WhisperKit/Core/Audio/AudioProcessor.swift (58 additions, 9 deletions)

@@ -93,8 +93,6 @@ public extension AudioProcessing {
     }

     static func padOrTrimAudio(fromArray audioArray: [Float], startAt startIndex: Int = 0, toLength frameLength: Int = 480_000, saveSegment: Bool = false) -> MLMultiArray? {
-        let currentFrameLength = audioArray.count
-
         guard startIndex >= 0 && startIndex < audioArray.count else {
             Logging.error("startIndex is outside the buffer size")
             return nil
@@ -197,7 +195,15 @@ public class AudioProcessor: NSObject, AudioProcessing {

         let audioFileURL = URL(fileURLWithPath: audioFilePath)
         let audioFile = try AVAudioFile(forReading: audioFileURL, commonFormat: .pcmFormatFloat32, interleaved: false)
+        return try loadAudio(fromFile: audioFile, startTime: startTime, endTime: endTime, maxReadFrameSize: maxReadFrameSize)
+    }
+
+    public static func loadAudio(
+        fromFile audioFile: AVAudioFile,
+        startTime: Double? = 0,
+        endTime: Double? = nil,
+        maxReadFrameSize: AVAudioFrameCount? = nil
+    ) throws -> AVAudioPCMBuffer {
         let sampleRate = audioFile.fileFormat.sampleRate
         let channelCount = audioFile.fileFormat.channelCount
         let frameLength = AVAudioFrameCount(audioFile.length)
@@ -243,13 +249,56 @@ public class AudioProcessor: NSObject, AudioProcessing {
         }
     }

+    public static func loadAudioAsFloatArray(
+        fromPath audioFilePath: String,
+        startTime: Double? = 0,
+        endTime: Double? = nil
+    ) throws -> [Float] {
+        guard FileManager.default.fileExists(atPath: audioFilePath) else {
+            throw WhisperError.loadAudioFailed("Resource path does not exist \(audioFilePath)")
+        }
+
+        let audioFileURL = URL(fileURLWithPath: audioFilePath)
+        let audioFile = try AVAudioFile(forReading: audioFileURL, commonFormat: .pcmFormatFloat32, interleaved: false)
+        let inputSampleRate = audioFile.fileFormat.sampleRate
+        let inputFrameCount = AVAudioFrameCount(audioFile.length)
+        let inputDuration = Double(inputFrameCount) / inputSampleRate
+
+        let start = startTime ?? 0
+        let end = min(endTime ?? inputDuration, inputDuration)
+
+        // Load 10m of audio at a time to reduce peak memory while converting
+        // Particularly impactful for large audio files
+        let chunkDuration: Double = 60 * 10
+        var currentTime = start
+        var result: [Float] = []
+
+        while currentTime < end {
+            let chunkEnd = min(currentTime + chunkDuration, end)
+
+            try autoreleasepool {
+                let buffer = try loadAudio(
+                    fromFile: audioFile,
+                    startTime: currentTime,
+                    endTime: chunkEnd
+                )
+
+                let floatArray = Self.convertBufferToArray(buffer: buffer)
+                result.append(contentsOf: floatArray)
+            }
+
+            currentTime = chunkEnd
+        }
+
+        return result
+    }
+
     public static func loadAudio(at audioPaths: [String]) async -> [Result<[Float], Swift.Error>] {
         await withTaskGroup(of: [(index: Int, result: Result<[Float], Swift.Error>)].self) { taskGroup -> [Result<[Float], Swift.Error>] in
             for (index, audioPath) in audioPaths.enumerated() {
                 taskGroup.addTask {
                     do {
-                        let audioBuffer = try AudioProcessor.loadAudio(fromPath: audioPath)
-                        let audio = AudioProcessor.convertBufferToArray(buffer: audioBuffer)
+                        let audio = try AudioProcessor.loadAudioAsFloatArray(fromPath: audioPath)
                         return [(index: index, result: .success(audio))]
                     } catch {
                         return [(index: index, result: .failure(error))]
@@ -280,10 +329,10 @@ public class AudioProcessor: NSObject, AudioProcessing {
         frameCount: AVAudioFrameCount? = nil,
         maxReadFrameSize: AVAudioFrameCount = Constants.defaultAudioReadFrameSize
     ) -> AVAudioPCMBuffer? {
-        let inputFormat = audioFile.fileFormat
+        let inputSampleRate = audioFile.fileFormat.sampleRate
         let inputStartFrame = audioFile.framePosition
         let inputFrameCount = frameCount ?? AVAudioFrameCount(audioFile.length)
-        let inputDuration = Double(inputFrameCount) / inputFormat.sampleRate
+        let inputDuration = Double(inputFrameCount) / inputSampleRate
         let endFramePosition = min(inputStartFrame + AVAudioFramePosition(inputFrameCount), audioFile.length + 1)

         guard let outputFormat = AVAudioFormat(standardFormatWithSampleRate: sampleRate, channels: channelCount) else {
@@ -305,8 +354,8 @@ public class AudioProcessor: NSObject, AudioProcessing {
             let remainingFrames = AVAudioFrameCount(endFramePosition - audioFile.framePosition)
             let framesToRead = min(remainingFrames, maxReadFrameSize)

-            let currentPositionInSeconds = Double(audioFile.framePosition) / inputFormat.sampleRate
-            let nextPositionInSeconds = (Double(audioFile.framePosition) + Double(framesToRead)) / inputFormat.sampleRate
+            let currentPositionInSeconds = Double(audioFile.framePosition) / inputSampleRate
+            let nextPositionInSeconds = (Double(audioFile.framePosition) + Double(framesToRead)) / inputSampleRate
             Logging.debug("Resampling \(String(format: "%.2f", currentPositionInSeconds))s - \(String(format: "%.2f", nextPositionInSeconds))s")

             do {
@@ -644,7 +693,7 @@ public class AudioProcessor: NSObject, AudioProcessing {
&propertySize,
&name
)
if status == noErr, let deviceNameCF = name?.takeUnretainedValue() as String? {
if status == noErr, let deviceNameCF = name?.takeRetainedValue() as String? {
deviceName = deviceNameCF
}

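Note (illustrative, not part of the diff): the new chunked loader reads and converts ten minutes at a time, so peak memory stays near one chunk plus the accumulated [Float] result, rather than the whole decoded file. A hedged usage sketch with a hypothetical path:

    let path = "/tmp/long_recording.wav" // hypothetical input file

    // New single call replacing the old loadAudio + convertBufferToArray pair.
    let samples: [Float] = try AudioProcessor.loadAudioAsFloatArray(fromPath: path)

    // Time ranges are supported too; this reads only the first minute.
    let firstMinute = try AudioProcessor.loadAudioAsFloatArray(fromPath: path, startTime: 0, endTime: 60)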

Sources/WhisperKit/Core/TextDecoder.swift (4 additions, 2 deletions)

@@ -591,9 +591,11 @@ open class TextDecoder: TextDecoding, WhisperMLModel {
         var hasAlignment = false
         var isFirstTokenLogProbTooLow = false
         let windowUUID = UUID()
-        DispatchQueue.global().async { [weak self] in
+        Task { [weak self] in
             guard let self = self else { return }
-            self.shouldEarlyStop[windowUUID] = false
+            await MainActor.run {
+                self.shouldEarlyStop[windowUUID] = false
+            }
         }
         for tokenIndex in prefilledIndex..<loopCount {
             let loopStart = Date()
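
Note (illustrative, not part of the diff): the old code reset the early-stop flag from a global dispatch queue while the decoding loop could be reading it; the new code funnels the write through the main actor, serializing access. A minimal standalone sketch of that pattern, assuming (as the diff implies) the flag dictionary is meant to be touched only on the main actor:

    import Foundation

    final class EarlyStopState {
        var flags: [UUID: Bool] = [:] // stand-in for shouldEarlyStop

        func reset(_ id: UUID) {
            Task { [weak self] in
                guard let self = self else { return }
                await MainActor.run {
                    self.flags[id] = false // write serialized on the main actor
                }
            }
        }
    }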

Sources/WhisperKit/Core/Utils.swift (10 additions, 1 deletion)

@@ -205,6 +205,14 @@ public extension String {
 }

 extension AVAudioPCMBuffer {
+    /// Converts the buffer to a float array
+    func asFloatArray() throws -> [Float] {
+        guard let data = floatChannelData?.pointee else {
+            throw WhisperError.audioProcessingFailed("Error converting audio, missing floatChannelData")
+        }
+        return Array(UnsafeBufferPointer(start: data, count: Int(frameLength)))
+    }
+
     /// Appends the contents of another buffer to the current buffer
     func appendContents(of buffer: AVAudioPCMBuffer) -> Bool {
         return appendContents(of: buffer, startingFrame: 0, frameCount: buffer.frameLength)
@@ -446,8 +454,9 @@ public func modelSupport(for deviceName: String, from config: ModelSupportConfig? = nil) -> (default: String, disabled: [String]) {
 /// Deprecated
 @available(*, deprecated, message: "Subject to removal in a future version. Use modelSupport(for:from:) -> ModelSupport instead.")
 @_disfavoredOverload
+@available(macOS 13, iOS 16, watchOS 10, visionOS 1, *)
 public func modelSupport(for deviceName: String, from config: ModelSupportConfig? = nil) -> (default: String, disabled: [String]) {
-    let modelSupport = modelSupport(for: deviceName, from: config)
+    let modelSupport: ModelSupport = modelSupport(for: deviceName, from: config)
     return (modelSupport.default, modelSupport.disabled)
 }

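Note (illustrative, not part of the diff): asFloatArray() reads only the first channel via floatChannelData?.pointee, which fits the mono, non-interleaved buffers WhisperKit loads. A hedged sketch with a hypothetical helper around it:

    import AVFoundation
    import WhisperKit

    // Hypothetical helper: read a whole file into a buffer, then convert.
    func samples(from file: AVAudioFile) throws -> [Float] {
        guard let buffer = AVAudioPCMBuffer(
            pcmFormat: file.processingFormat,
            frameCapacity: AVAudioFrameCount(file.length)
        ) else {
            throw WhisperError.audioProcessingFailed("Could not allocate buffer")
        }
        try file.read(into: buffer)
        return try buffer.asFloatArray() // new extension from this diff
    }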

Sources/WhisperKit/Core/WhisperKit.swift (26 additions, 23 deletions)

@@ -446,7 +446,8 @@ open class WhisperKit {
     open func detectLanguage(
         audioPath: String
     ) async throws -> (language: String, langProbs: [String: Float]) {
-        let audioBuffer = try AudioProcessor.loadAudio(fromPath: audioPath)
+        // Only need the first 30s for language detection
+        let audioBuffer = try AudioProcessor.loadAudio(fromPath: audioPath, endTime: 30.0)
         let audioArray = AudioProcessor.convertBufferToArray(buffer: audioBuffer)
         return try await detectLangauge(audioArray: audioArray)
     }
@@ -721,15 +722,17 @@ open class WhisperKit {
         callback: TranscriptionCallback = nil
     ) async throws -> [TranscriptionResult] {
         // Process input audio file into audio samples
-        let loadAudioStart = Date()
-        let audioBuffer = try AudioProcessor.loadAudio(fromPath: audioPath)
-        let loadTime = Date().timeIntervalSince(loadAudioStart)
+        let audioArray = try await withThrowingTaskGroup(of: [Float].self) { group -> [Float] in
+            let convertAudioStart = Date()
+            defer {
+                let convertTime = Date().timeIntervalSince(convertAudioStart)
+                currentTimings.audioLoading = convertTime
+                Logging.debug("Audio loading and convert time: \(convertTime)")
+                logCurrentMemoryUsage("Audio Loading and Convert")
+            }

-        let convertAudioStart = Date()
-        let audioArray = AudioProcessor.convertBufferToArray(buffer: audioBuffer)
-        let convertTime = Date().timeIntervalSince(convertAudioStart)
-        currentTimings.audioLoading = loadTime + convertTime
-        Logging.debug("Audio loading time: \(loadTime), Audio convert time: \(convertTime)")
+            return try AudioProcessor.loadAudioAsFloatArray(fromPath: audioPath)
+        }

         let transcribeResults: [TranscriptionResult] = try await transcribe(
             audioArray: audioArray,
@@ -837,23 +840,23 @@ open class WhisperKit {
             throw WhisperError.tokenizerUnavailable()
         }

-        let childProgress = Progress()
-        progress.totalUnitCount += 1
-        progress.addChild(childProgress, withPendingUnitCount: 1)
-
-        let transcribeTask = TranscribeTask(
-            currentTimings: currentTimings,
-            progress: childProgress,
-            audioEncoder: audioEncoder,
-            featureExtractor: featureExtractor,
-            segmentSeeker: segmentSeeker,
-            textDecoder: textDecoder,
-            tokenizer: tokenizer
-        )
-
         do {
             try Task.checkCancellation()

+            let childProgress = Progress()
+            progress.totalUnitCount += 1
+            progress.addChild(childProgress, withPendingUnitCount: 1)
+
+            let transcribeTask = TranscribeTask(
+                currentTimings: currentTimings,
+                progress: childProgress,
+                audioEncoder: audioEncoder,
+                featureExtractor: featureExtractor,
+                segmentSeeker: segmentSeeker,
+                textDecoder: textDecoder,
+                tokenizer: tokenizer
+            )
+
             let transcribeTaskResult = try await transcribeTask.run(
                 audioArray: audioArray,
                 decodeOptions: decodeOptions,
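
Note (illustrative, not part of the diff): with the endTime cap above, detectLanguage reads at most the first 30 seconds of audio instead of the whole file. A small usage sketch, assuming an initialized WhisperKit instance named `pipe` and a hypothetical path:

    let (language, probabilities) = try await pipe.detectLanguage(audioPath: "/tmp/audio.wav")
    let topCandidates = probabilities.sorted { $0.value > $1.value }.prefix(3)
    print("Detected \(language); top candidates: \(topCandidates)")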

Sources/WhisperKitCLI/CLIArguments.swift (4 additions, 4 deletions)

@@ -103,9 +103,9 @@ struct CLIArguments: ParsableArguments {
     @Flag(help: "Simulate streaming transcription using the input audio file")
     var streamSimulated: Bool = false

-    @Option(help: "Maximum concurrent inference, might be helpful when processing more than 1 audio file at the same time. 0 means unlimited")
-    var concurrentWorkerCount: Int = 0
+    @Option(help: "Maximum concurrent inference, might be helpful when processing more than 1 audio file at the same time. 0 means unlimited. Default: 4")
+    var concurrentWorkerCount: Int = 4

-    @Option(help: "Chunking strategy for audio processing, `nil` means no chunking, `vad` means using voice activity detection")
-    var chunkingStrategy: String? = nil
+    @Option(help: "Chunking strategy for audio processing, `none` means no chunking, `vad` means using voice activity detection. Default: `vad`")
+    var chunkingStrategy: String = "vad"
 }
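
Note (illustrative, not part of the diff): assuming swift-argument-parser's standard kebab-case naming for the options above, restoring the old behavior from the command line would look something like this (the invocation path is a guess, not confirmed by the diff):

    swift run whisperkit-cli transcribe --audio-path audio.wav --chunking-strategy none --concurrent-worker-count 0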

Sources/WhisperKitCLI/TranscribeCLI.swift (3 additions, 11 deletions)

@@ -38,10 +38,8 @@ struct TranscribeCLI: AsyncParsableCommand {
             cliArguments.audioPath = audioFiles.map { audioFolder + "/" + $0 }
         }

-        if let chunkingStrategyRaw = cliArguments.chunkingStrategy {
-            if ChunkingStrategy(rawValue: chunkingStrategyRaw) == nil {
-                throw ValidationError("Wrong chunking strategy \"\(chunkingStrategyRaw)\", valid strategies: \(ChunkingStrategy.allCases.map { $0.rawValue })")
-            }
+        if ChunkingStrategy(rawValue: cliArguments.chunkingStrategy) == nil {
+            throw ValidationError("Wrong chunking strategy \"\(cliArguments.chunkingStrategy)\", valid strategies: \(ChunkingStrategy.allCases.map { $0.rawValue })")
         }
     }

@@ -318,12 +316,6 @@ struct TranscribeCLI: AsyncParsableCommand {
     }

     private func decodingOptions(task: DecodingTask) -> DecodingOptions {
-        let chunkingStrategy: ChunkingStrategy? =
-            if let chunkingStrategyRaw = cliArguments.chunkingStrategy {
-                ChunkingStrategy(rawValue: chunkingStrategyRaw)
-            } else {
-                nil
-            }
         return DecodingOptions(
             verbose: cliArguments.verbose,
             task: task,
@@ -344,7 +336,7 @@ struct TranscribeCLI: AsyncParsableCommand {
             firstTokenLogProbThreshold: cliArguments.firstTokenLogProbThreshold,
             noSpeechThreshold: cliArguments.noSpeechThreshold ?? 0.6,
             concurrentWorkerCount: cliArguments.concurrentWorkerCount,
-            chunkingStrategy: chunkingStrategy
+            chunkingStrategy: ChunkingStrategy(rawValue: cliArguments.chunkingStrategy)
         )
     }

Tests/WhisperKitTests/UnitTests.swift (4 additions, 1 deletion)

@@ -548,9 +548,11 @@ final class UnitTests: XCTestCase {
     }

     func testDecodingEarlyStopping() async throws {
+        let earlyStopTokenCount = 10
         let options = DecodingOptions()
         let continuationCallback: TranscriptionCallback = { (progress: TranscriptionProgress) -> Bool? in
-            false
+            // Stop after only 10 tokens (full test audio contains 16)
+            return progress.tokens.count <= earlyStopTokenCount
         }

         let result = try await XCTUnwrapAsync(
@@ -576,6 +578,7 @@ final class UnitTests: XCTestCase {
         XCTAssertNotNil(resultWithWait)
         let tokenCountWithWait = resultWithWait.segments.flatMap { $0.tokens }.count
         let decodingTimePerTokenWithWait = resultWithWait.timings.decodingLoop / Double(tokenCountWithWait)
+        Logging.debug("Decoding loop without wait: \(result.timings.decodingLoop), with wait: \(resultWithWait.timings.decodingLoop)")

         // Assert that the decoding predictions per token are not slower with the waiting
         XCTAssertEqual(decodingTimePerTokenWithWait, decodingTimePerToken, accuracy: decodingTimePerToken, "Decoding predictions per token should not be significantly slower with waiting")