diff --git a/create_spectrogram.py b/create_spectrogram.py index eb00e0fc..6228044e 100644 --- a/create_spectrogram.py +++ b/create_spectrogram.py @@ -7,7 +7,7 @@ from scipy.io import wavfile -def create_spec_name(wav_name, output_dir): +def create_spec_name(wav_name, output_dir=None): """Creates appropriate path to the spectrogram from input .wav file and output directory. Args: @@ -18,7 +18,6 @@ def create_spec_name(wav_name, output_dir): """ spec_name = path.splitext(path.basename(wav_name))[0] if output_dir is not None: - Path(output_dir).mkdir(parents=True, exist_ok=True) spec_name = f"{path.normpath(output_dir)}/{spec_name}" return f"{spec_name}.png" @@ -46,7 +45,7 @@ def save_spectrogram(input_wav, plot_path=None, NFFT=256): `plot_path`: Path to the output spectrogram file. Default is `input_wav` with .png extension. `NFFT`: The number of data points used in each block for the FFT. A power 2 is most efficient. Returns: - None + Path to the spectrogram. """ samplerate, data = wavfile.read(input_wav) noverlap = NFFT // 2 if NFFT <= 128 else 128 @@ -68,11 +67,14 @@ def save_spectrogram(input_wav, plot_path=None, NFFT=256): if plot_path is None: plot_path = input_wav.replace(".wav", ".png") + else: + Path(path.dirname(plot_path)).mkdir(parents=True, exist_ok=True) plt.savefig(plot_path) plt.cla() plt.close("all") logging.info("Finished " + input_wav) + return plot_path if __name__ == "__main__": diff --git a/tests/ooi.wav b/tests/ooi.wav new file mode 100755 index 00000000..698e3126 Binary files /dev/null and b/tests/ooi.wav differ diff --git a/tests/orcasound.wav b/tests/orcasound.wav new file mode 100755 index 00000000..0fe1adb5 Binary files /dev/null and b/tests/orcasound.wav differ diff --git a/tests/test_spectrograms.py b/tests/test_spectrograms.py new file mode 100644 index 00000000..5d7e4eb8 --- /dev/null +++ b/tests/test_spectrograms.py @@ -0,0 +1,44 @@ +"""Unit tests for various util functions relating to spectrogram creation""" +import os.path + +import pytest + +from create_spectrogram import create_spec_name, save_spectrogram + + +@pytest.mark.parametrize( + "wav_name, output_dir, expected", + [ + ("2021-01-01T00-00-00-000.wav", None, "2021-01-01T00-00-00-000.png"), + ( + "2021-01-01T00-00-00-000.wav", + "spectrograms", + "spectrograms/2021-01-01T00-00-00-000.png", + ), + ( + "2021-01-01T00-00-00-000.wav", + "long/path/to/output", + "long/path/to/output/2021-01-01T00-00-00-000.png", + ), + ( + "2021-01-01T00-00-00-000.wav", + "end/slash/", + "end/slash/2021-01-01T00-00-00-000.png", + ), + ], +) +def test_create_spec_name(wav_name, output_dir, expected): + assert create_spec_name(wav_name, output_dir) == expected + + +@pytest.mark.parametrize( + "wav_name, plot_path", + [ + ("tests/ooi.wav", None), + ("tests/ooi.wav", "tests/path/to/output/ooi.png"), + ("tests/orcasound.wav", "tests/path/to/output/orcasound.png"), + ], +) +def test_save_spectrogram(wav_name, plot_path): + spec_path = save_spectrogram(wav_name, plot_path) + assert os.path.isfile(spec_path)