diff --git a/F5TTS.py b/F5TTS.py index 11c3683..569e321 100644 --- a/F5TTS.py +++ b/F5TTS.py @@ -17,7 +17,7 @@ from comfy.utils import ProgressBar from cached_path import cached_path sys.path.append(Install.f5TTSPath) -from model import DiT # noqa E402 +from model import DiT,UNetT # noqa E402 from model.utils_infer import ( # noqa E402 load_model, preprocess_ref_audio_text, @@ -28,6 +28,7 @@ class F5TTSCreate: voice_reg = re.compile(r"\{(\w+)\}") + model_types = ["F5", "E2"] tooltip_seed = "Seed. -1 = random" def is_voice_name(self, word): @@ -54,7 +55,33 @@ def load_voice(ref_audio, ref_text): ) return main_voice - def load_model(self): + def load_model(self, model): + models = { + "F5": self.load_f5_model, + "E2": self.load_e2_model, + } + return models[model]() + + def get_vocab_file(self): + return os.path.join( + Install.f5TTSPath, "data/Emilia_ZH_EN_pinyin/vocab.txt" + ) + + def load_e2_model(self): + model_cls = UNetT + model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4) + repo_name = "E2-TTS" + exp_name = "E2TTS_Base" + ckpt_step = 1200000 + ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors")) # noqa E501 + vocab_file = self.get_vocab_file() + ema_model = load_model( + model_cls, model_cfg, + ckpt_file, vocab_file + ) + return ema_model + + def load_f5_model(self): model_cls = DiT model_cfg = dict( dim=1024, depth=22, heads=16, @@ -64,10 +91,11 @@ def load_model(self): exp_name = "F5TTS_Base" ckpt_step = 1200000 ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors")) # noqa E501 - vocab_file = os.path.join( - Install.f5TTSPath, "data/Emilia_ZH_EN_pinyin/vocab.txt" + vocab_file = self.get_vocab_file() + ema_model = load_model( + model_cls, model_cfg, + ckpt_file, vocab_file ) - ema_model = load_model(model_cls, model_cfg, ckpt_file, vocab_file) return ema_model def generate_audio(self, voices, model_obj, chunks, seed): @@ -117,8 +145,8 @@ def generate_audio(self, voices, model_obj, chunks, seed): os.unlink(wave_file.name) return audio - def create(self, voices, chunks, seed=-1): - model_obj = self.load_model() + def create(self, voices, chunks, seed=-1, model="F5"): + model_obj = self.load_model(model) return self.generate_audio(voices, model_obj, chunks, seed) @@ -141,6 +169,7 @@ def INPUT_TYPES(s): "default": 1, "min": -1, "tooltip": F5TTSCreate.tooltip_seed, }), + "model": (F5TTSCreate.model_types,), }, } @@ -174,7 +203,7 @@ def remove_wave_file(self): print("F5TTS: Cannot remove? "+self.wave_file.name) print(e) - def create(self, sample_audio, sample_text, speech, seed=-1): + def create(self, sample_audio, sample_text, speech, seed=-1, model="F5"): try: main_voice = self.load_voice_from_input(sample_audio, sample_text) @@ -184,7 +213,7 @@ def create(self, sample_audio, sample_text, speech, seed=-1): chunks = f5ttsCreate.split_text(speech) voices['main'] = main_voice - audio = f5ttsCreate.create(voices, chunks, seed) + audio = f5ttsCreate.create(voices, chunks, seed, model) finally: self.remove_wave_file() return (audio, ) @@ -233,6 +262,7 @@ def INPUT_TYPES(s): "default": 1, "min": -1, "tooltip": F5TTSCreate.tooltip_seed, }), + "model": (F5TTSCreate.model_types,), } } @@ -289,7 +319,7 @@ def load_voices_from_files(self, sample, voice_names): voices[voice_name] = self.load_voice_from_file(sample_file) return voices - def create(self, sample, speech, seed=-1): + def create(self, sample, speech, seed=-1, model="F5"): # Install.check_install() main_voice = self.load_voice_from_file(sample) @@ -309,7 +339,7 @@ def create(self, sample, speech, seed=-1): voices = self.load_voices_from_files(sample, voice_names) voices['main'] = main_voice - audio = f5ttsCreate.create(voices, chunks, seed) + audio = f5ttsCreate.create(voices, chunks, seed, model) return (audio, ) @classmethod diff --git a/pyproject.toml b/pyproject.toml index 4e4a09c..3e246fc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "comfyui-f5-tts" description = "Text to speech with F5-TTS" -version = "1.0.4" +version = "1.0.5" license = {text = "MIT License"} [project.urls]