Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

General code cleanup and many minor fixes #50

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions basic_pitch/commandline_printing.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def generating_file_message(output_type: str) -> None:


def file_saved_confirmation(output_type: str, save_path: Union[pathlib.Path, str]) -> None:
"""Print a confirmation that the file was saved succesfully
"""Print a confirmation that the file was saved successfully

Args:
output_type: The kind of file that is being generated.
Expand All @@ -53,15 +53,16 @@ def file_saved_confirmation(output_type: str, save_path: Union[pathlib.Path, str
print(f" {OUTPUT_EMOJIS[output_type]} Saved to {save_path}")


def failed_to_save(output_type: str, save_path: Union[pathlib.Path, str]) -> None:
def failed_to_save(output_type: str, save_path: Union[pathlib.Path, str], e: Exception) -> None:
"""Print a failure to save message

Args:
output_type: The kind of file that is being generated.
save_path: The path to output file.
e: The exception that was raised.

"""
print(f"\n🚨 Failed to save {output_type.replace('_', ' ').lower()} to {save_path} \n")
print(f"\n🚨 Failed to save {output_type.replace('_', ' ').lower()} to {save_path} due to {e}\n")


@contextmanager
Expand Down
42 changes: 24 additions & 18 deletions basic_pitch/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def window_audio_file(audio_original: Tensor, hop_size: int) -> Tuple[Tensor, Li
window_times: list of {'start':.., 'end':...} objects (times in seconds)

"""
from tensorflow import expand_dims # imporing this here so the module loads faster
from tensorflow import expand_dims # importing this here so the module loads faster

audio_windowed = expand_dims(
signal.frame(audio_original, AUDIO_N_SAMPLES, hop_size, pad_end=True, pad_value=0),
Expand Down Expand Up @@ -85,7 +85,7 @@ def get_audio_input(
length of original audio file, in frames, BEFORE padding.

"""
assert overlap_len % 2 == 0, "overlap_length must be even, got {}".format(overlap_len)
assert overlap_len % 2 == 0, f"overlap_length must be even, got {overlap_len}"

audio_original, _ = librosa.load(str(audio_path), sr=AUDIO_SAMPLE_RATE, mono=True)

Expand Down Expand Up @@ -250,8 +250,8 @@ def save_note_events(
save_path: The location we're saving it
"""

with open(save_path, "w") as fhandle:
writer = csv.writer(fhandle, delimiter=",")
with open(save_path, "w") as f_handle:
writer = csv.writer(f_handle, delimiter=",")
writer.writerow(["start_time_s", "end_time_s", "pitch_midi", "velocity", "pitch_bend"])
for start_time, end_time, note_number, amplitude, pitch_bend in note_events:
row = [start_time, end_time, note_number, int(np.round(127 * amplitude))]
Expand Down Expand Up @@ -280,8 +280,8 @@ def predict(
onset_threshold: Minimum energy required for an onset to be considered present.
frame_threshold: Minimum energy requirement for a frame to be considered present.
minimum_note_length: The minimum allowed note length in frames.
minimum_freq: Minimum allowed output frequency, in Hz. If None, all frequencies are used.
maximum_freq: Maximum allowed output frequency, in Hz. If None, all frequencies are used.
minimum_frequency: Minimum allowed output frequency, in Hz. If None, all frequencies are used.
maximum_frequency: Maximum allowed output frequency, in Hz. If None, all frequencies are used.
multiple_pitch_bends: If True, allow overlapping notes in midi file to have pitch bends.
melodia_trick: Use the melodia post-processing step.
debug_file: An optional path to output debug data to. Useful for testing/verification.
Expand Down Expand Up @@ -364,15 +364,15 @@ def predict_and_save(
audio_path_list: List of file paths for the audio to run inference on.
output_directory: Directory to output MIDI and all other outputs derived from the model to.
save_midi: True to save midi.
sonify_midi: Whether or not to render audio from the MIDI and output it to a file.
sonify_midi: Whether to render audio from the MIDI and output it to a file.
save_model_outputs: True to save contours, onsets and notes from the model prediction.
save_notes: True to save note events.
model_path: Path to load the Keras saved model from. Can be local or on GCS.
onset_threshold: Minimum energy required for an onset to be considered present.
frame_threshold: Minimum energy requirement for a frame to be considered present.
minimum_note_length: The minimum allowed note length in frames.
minimum_freq: Minimum allowed output frequency, in Hz. If None, all frequencies are used.
maximum_freq: Maximum allowed output frequency, in Hz. If None, all frequencies are used.
minimum_frequency: Minimum allowed output frequency, in Hz. If None, all frequencies are used.
maximum_frequency: Maximum allowed output frequency, in Hz. If None, all frequencies are used.
multiple_pitch_bends: If True, allow overlapping notes in midi file to have pitch bends.
melodia_trick: Use the melodia post-processing step.
debug_file: An optional path to output debug data to. Useful for testing/verification.
Expand Down Expand Up @@ -400,34 +400,40 @@ def predict_and_save(
model_output_path = build_output_path(audio_path, output_directory, OutputExtensions.MODEL_OUTPUT_NPZ)
try:
np.savez(model_output_path, basic_pitch_model_output=model_output)
except Exception as e:
failed_to_save(OutputExtensions.MODEL_OUTPUT_NPZ.name, model_output_path, e)
else:
file_saved_confirmation(OutputExtensions.MODEL_OUTPUT_NPZ.name, model_output_path)
except Exception:
failed_to_save(OutputExtensions.MODEL_OUTPUT_NPZ.name, model_output_path)

if save_midi:
midi_path = build_output_path(audio_path, output_directory, OutputExtensions.MIDI)
try:
midi_data.write(str(midi_path))
except Exception as e:
failed_to_save(OutputExtensions.MIDI.name, midi_path, e)
else:
file_saved_confirmation(OutputExtensions.MIDI.name, midi_path)
except Exception:
failed_to_save(OutputExtensions.MIDI.name, midi_path)


if sonify_midi:
midi_sonify_path = build_output_path(audio_path, output_directory, OutputExtensions.MIDI_SONIFICATION)
try:
infer.sonify_midi(midi_data, midi_sonify_path, sr=sonification_samplerate)
except Exception as e:
failed_to_save(OutputExtensions.MIDI_SONIFICATION.name, midi_sonify_path, e)
else:
file_saved_confirmation(OutputExtensions.MIDI_SONIFICATION.name, midi_sonify_path)
except Exception:
failed_to_save(OutputExtensions.MIDI_SONIFICATION.name, midi_sonify_path)

if save_notes:
note_events_path = build_output_path(audio_path, output_directory, OutputExtensions.NOTE_EVENTS)
try:
save_note_events(note_events, note_events_path)
except Exception as e:
failed_to_save(OutputExtensions.NOTE_EVENTS.name, note_events_path, e)
else:
file_saved_confirmation(OutputExtensions.NOTE_EVENTS.name, note_events_path)
except Exception:
failed_to_save(OutputExtensions.NOTE_EVENTS.name, note_events_path)
except Exception:

except IOError:
print("🚨 Something went wrong 😔 - see the traceback below for details.")
print("")
print(traceback.format_exc())
Loading