diff --git a/.github/workflows/unittests.yml b/.github/workflows/unittests.yml
index 956c2215..e12d6609 100644
--- a/.github/workflows/unittests.yml
+++ b/.github/workflows/unittests.yml
@@ -31,7 +31,7 @@ jobs:
       - name: Install Python dependencies
         run: |
           python -m pip install 'pocketsphinx<5'
-          python -m pip install git+https://github.com/openai/whisper.git soundfile
+          python -m pip install openai-whisper soundfile
           python -m pip install openai
           python -m pip install .
       - name: Test with unittest
diff --git a/README.rst b/README.rst
index 410e289d..d86ba5bf 100644
--- a/README.rst
+++ b/README.rst
@@ -169,7 +169,7 @@ Whisper (for Whisper users)
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 Whisper is **required if and only if you want to use whisper** (``recognizer_instance.recognize_whisper``).
-You can install it with ``python3 -m pip install git+https://github.com/openai/whisper.git soundfile``.
+You can install it with ``python3 -m pip install openai-whisper soundfile``.
 
 Whisper API (for Whisper API users)
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
diff --git a/tests/test_recognition.py b/tests/test_recognition.py
index e5937701..2176023c 100644
--- a/tests/test_recognition.py
+++ b/tests/test_recognition.py
@@ -14,7 +14,6 @@ def setUp(self):
         self.AUDIO_FILE_EN = os.path.join(os.path.dirname(os.path.realpath(__file__)), "english.wav")
         self.AUDIO_FILE_FR = os.path.join(os.path.dirname(os.path.realpath(__file__)), "french.aiff")
         self.AUDIO_FILE_ZH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "chinese.flac")
-        self.WHISPER_CONFIG = {"temperature": 0}
 
     def test_recognizer_attributes(self):
         r = sr.Recognizer()
@@ -81,21 +80,6 @@ def test_ibm_chinese(self):
         with sr.AudioFile(self.AUDIO_FILE_ZH) as source: audio = r.record(source)
         self.assertEqual(r.recognize_ibm(audio, username=os.environ["IBM_USERNAME"], password=os.environ["IBM_PASSWORD"], language="zh-CN"), u"砸 自己 的 脚 ")
 
-    def test_whisper_english(self):
-        r = sr.Recognizer()
-        with sr.AudioFile(self.AUDIO_FILE_EN) as source: audio = r.record(source)
-        self.assertEqual(r.recognize_whisper(audio, language="english", **self.WHISPER_CONFIG), " 1, 2, 3")
-
-    def test_whisper_french(self):
-        r = sr.Recognizer()
-        with sr.AudioFile(self.AUDIO_FILE_FR) as source: audio = r.record(source)
-        self.assertEqual(r.recognize_whisper(audio, language="french", **self.WHISPER_CONFIG), " et c'est la dictée numéro 1.")
-
-    def test_whisper_chinese(self):
-        r = sr.Recognizer()
-        with sr.AudioFile(self.AUDIO_FILE_ZH) as source: audio = r.record(source)
-        self.assertEqual(r.recognize_whisper(audio, model="small", language="chinese", **self.WHISPER_CONFIG), u"砸自己的腳")
-
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/tests/test_whisper_recognition.py b/tests/test_whisper_recognition.py
new file mode 100644
index 00000000..a0054961
--- /dev/null
+++ b/tests/test_whisper_recognition.py
@@ -0,0 +1,78 @@
+from unittest import TestCase
+from unittest.mock import MagicMock, patch
+
+import numpy as np
+
+from speech_recognition import AudioData, Recognizer
+
+
+@patch("speech_recognition.io.BytesIO")
+@patch("soundfile.read")
+@patch("torch.cuda.is_available")
+@patch("whisper.load_model")
+class RecognizeWhisperTestCase(TestCase):
+    def test_default_parameters(
+        self, load_model, is_available, sf_read, BytesIO
+    ):
+        whisper_model = load_model.return_value
+        transcript = whisper_model.transcribe.return_value
+        audio_array = MagicMock()
+        dummy_sampling_rate = 99_999
+        sf_read.return_value = (audio_array, dummy_sampling_rate)
+
+        recognizer = Recognizer()
+        audio_data = MagicMock(spec=AudioData)
+        actual = recognizer.recognize_whisper(audio_data)
+
+        self.assertEqual(actual, transcript.__getitem__.return_value)
+        load_model.assert_called_once_with("base")
+        audio_data.get_wav_data.assert_called_once_with(convert_rate=16000)
+        BytesIO.assert_called_once_with(audio_data.get_wav_data.return_value)
+        sf_read.assert_called_once_with(BytesIO.return_value)
+        audio_array.astype.assert_called_once_with(np.float32)
+        whisper_model.transcribe.assert_called_once_with(
+            audio_array.astype.return_value,
+            language=None,
+            task=None,
+            fp16=is_available.return_value,
+        )
+        transcript.__getitem__.assert_called_once_with("text")
+
+    def test_return_as_dict(self, load_model, is_available, sf_read, BytesIO):
+        whisper_model = load_model.return_value
+        audio_array = MagicMock()
+        dummy_sampling_rate = 99_999
+        sf_read.return_value = (audio_array, dummy_sampling_rate)
+
+        recognizer = Recognizer()
+        audio_data = MagicMock(spec=AudioData)
+        actual = recognizer.recognize_whisper(audio_data, show_dict=True)
+
+        self.assertEqual(actual, whisper_model.transcribe.return_value)
+
+    def test_pass_parameters(self, load_model, is_available, sf_read, BytesIO):
+        whisper_model = load_model.return_value
+        transcript = whisper_model.transcribe.return_value
+        audio_array = MagicMock()
+        dummy_sampling_rate = 99_999
+        sf_read.return_value = (audio_array, dummy_sampling_rate)
+
+        recognizer = Recognizer()
+        audio_data = MagicMock(spec=AudioData)
+        actual = recognizer.recognize_whisper(
+            audio_data,
+            model="small",
+            language="english",
+            translate=True,
+            temperature=0,
+        )
+
+        self.assertEqual(actual, transcript.__getitem__.return_value)
+        load_model.assert_called_once_with("small")
+        whisper_model.transcribe.assert_called_once_with(
+            audio_array.astype.return_value,
+            language="english",
+            task="translate",
+            fp16=is_available.return_value,
+            temperature=0,
+        )
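
Note for reviewers (not part of the diff): the new mocked tests above exercise the same call path a real user would take. A minimal usage sketch, assuming the package is installed alongside the dependencies named in the README change (openai-whisper, soundfile) and that an english.wav file exists locally; the keyword arguments mirror those asserted in test_pass_parameters:

    import speech_recognition as sr

    r = sr.Recognizer()
    # Record the entire audio file into an AudioData instance
    with sr.AudioFile("english.wav") as source:
        audio = r.record(source)

    # Loads the "base" Whisper model by default; model= and language=
    # are forwarded to whisper.transcribe(), as the mocks assert above
    print(r.recognize_whisper(audio, language="english"))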