Skip to content

Commit

Permalink
fix
Browse files (browse the repository at this point in the history)
  • Loading branch information
AgainstEntropy committed Jan 17, 2025
1 parent 7c2bed0 commit 6b72c79
Showing 1 changed file with 14 additions and 14 deletions.
28 changes: 14 additions & 14 deletions nexa/gguf/nexa_inference_audio_lm.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -148,17 +148,17 @@ def run(self):
Run the audio language model inference loop.
"""
from nexa.gguf.llama._utils_spinner import start_spinner, stop_spinner

try:
while True:
audio_path = self._get_valid_audio_path()
user_input = nexa_prompt("Enter text (leave empty if no prompt): ")

  stop_event, spinner_thread = start_spinner(
- style="default",
+ style="default",
  message=""
  )

try:
# with suppress_stdout_stderr():
# response = self.inference(audio_path, user_input)
Expand All @@ -174,7 +174,7 @@ def run(self):
print() # '\n'
finally:
stop_spinner(stop_event, spinner_thread)

self.cleanup()

except KeyboardInterrupt:
Expand Down Expand Up @@ -252,8 +252,8 @@ def inference_streaming(self, audio_path: str, prompt: str = "") -> str:
)
res = 0
while res >= 0:
- res = audio_lm_cpp.sample(oss)
- res_str = audio_lm_cpp.get_str(oss).decode('utf-8')
+ res = audio_lm_cpp.sample(oss, is_qwen=self.is_qwen)
+ res_str = audio_lm_cpp.get_str(oss, is_qwen=self.is_qwen).decode('utf-8')

if '<|im_start|>' in res_str or '</s>' in res_str:
continue
Expand All @@ -268,7 +268,7 @@ def cleanup(self):
if self.context:
audio_lm_cpp.free(self.context, is_qwen=self.is_qwen)
self.context = None

if self.temp_file and os.path.exists(self.temp_file):
try:
os.remove(self.temp_file)
Expand All @@ -290,7 +290,7 @@ def _ensure_16khz(self, audio_path: str) -> str:
"""
try:
base_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# Create tmp directory if it doesn't exist
tmp_dir = os.path.join(base_dir, 'tmp')
os.makedirs(tmp_dir, exist_ok=True)
Expand All @@ -300,7 +300,7 @@ def _ensure_16khz(self, audio_path: str) -> str:

if sr == 16000:
return audio_path

# Resample to 16kHz
print(f"Resampling audio from {sr} to 16000")
y_resampled = librosa.resample(y=y, orig_sr=sr, target_sr=16000)
Expand All @@ -309,15 +309,15 @@ def _ensure_16khz(self, audio_path: str) -> str:
original_name = os.path.splitext(os.path.basename(audio_path))[0]
tmp_filename = f"resampled_{original_name}_16khz_{int(time.time())}.wav"
tmp_path = os.path.join(tmp_dir, tmp_filename)

# Save the resampled audio
  sf.write(
- tmp_path,
- y_resampled,
+ tmp_path,
+ y_resampled,
  16000,
  subtype='PCM_16'
  )

# Store the path for cleanup
self.temp_file = tmp_path
return tmp_path
Expand Down

0 comments on commit 6b72c79

Please sign in to comment.