Skip to content

Commit

Permalink
Added new kwarg to Speech service (#35)
Browse files Browse the repository at this point in the history
  • Loading branch information
osolmaz authored Jan 7, 2023
1 parent 609b019 commit e2a8826
Show file tree
Hide file tree
Showing 3 changed files with 1,436 additions and 1,407 deletions.
37 changes: 30 additions & 7 deletions manim_voiceover/services/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,20 @@ def __init__(
global_speed: float = 1.00,
cache_dir: str = None,
transcription_model: str = None,
transcription_kwargs: dict = {},
**kwargs
):
"""
Args:
global_speed (float, optional): The speed at which to play the audio. Defaults to 1.00.
cache_dir (str, optional): The directory to save the audio files to. Defaults to ``voiceovers/``.
transcription_model (str, optional): The `OpenAI Whisper model <https://github.com/openai/whisper#available-models-and-languages>`_ to use for transcription. Defaults to None.
global_speed (float, optional): The speed at which to play the audio.
Defaults to 1.00.
cache_dir (str, optional): The directory to save the audio
files to. Defaults to ``voiceovers/``.
transcription_model (str, optional): The
`OpenAI Whisper model <https://github.com/openai/whisper#available-models-and-languages>`_
to use for transcription. Defaults to None.
transcription_kwargs (dict, optional): Keyword arguments to
pass to the transcribe() function. Defaults to {}.
"""
self.global_speed = global_speed

Expand All @@ -65,11 +72,11 @@ def __init__(
if not os.path.exists(self.cache_dir):
os.makedirs(self.cache_dir)

self.transcription_model = None
self._whisper_model = None
self.transcription_model = transcription_model
self.set_transcription(model=transcription_model, kwargs=transcription_kwargs)

if self.transcription_model is not None:
self._whisper_model = get_whisper_model(self.transcription_model)
self.additional_kwargs = kwargs

def _wrap_generate_from_text(self, text: str, path: str = None, **kwargs) -> dict:
# Replace newlines with lines, reduce multiple consecutive spaces to single
Expand All @@ -81,7 +88,7 @@ def _wrap_generate_from_text(self, text: str, path: str = None, **kwargs) -> dic
# Check whether word boundaries exist and if not run stt
if "word_boundaries" not in dict_ and self._whisper_model is not None:
transcription_result = self._whisper_model.transcribe(
str(Path(self.cache_dir) / original_audio)
str(Path(self.cache_dir) / original_audio), **self.transcription_kwargs
)
logger.info("Transcription: " + transcription_result["text"])
word_boundaries = timestamps_to_word_boundaries(
Expand Down Expand Up @@ -116,6 +123,22 @@ def _wrap_generate_from_text(self, text: str, path: str = None, **kwargs) -> dic
)
return dict_

def set_transcription(self, model: str = None, kwargs: dict = {}):
"""Set the transcription model and keyword arguments to be passed
to the transcribe() function.
Args:
model (str, optional): The Whisper model to use for transcription. Defaults to None.
kwargs (dict, optional): Keyword arguments to pass to the transcribe() function. Defaults to {}.
"""
if model != self.transcription_model:
if model is not None:
self._whisper_model = get_whisper_model(model)
else:
self._whisper_model = None

self.transcription_kwargs = kwargs

def get_data_hash(self, data: dict) -> str:
dumped_data = json.dumps(data)
data_hash = hashlib.sha256(dumped_data.encode("utf-8")).hexdigest()
Expand Down
Loading

0 comments on commit e2a8826

Please sign in to comment.