Commit 62e4d7b: Added input from audio + text node.

niknah committed Nov 5, 2024
1 parent 6288654 commit 62e4d7b

Showing 5 changed files with 426 additions and 83 deletions.
F5TTS.py: 238 changes (161 additions, 77 deletions)
@@ -12,6 +12,7 @@
import sys
import numpy as np
import re
+import io
from comfy.utils import ProgressBar
from cached_path import cached_path
sys.path.append(Install.f5TTSPath)
@@ -24,60 +25,32 @@
sys.path.pop()


-class F5TTSAudio:
-
-    def __init__(self):
-        self.use_cli = False
-        self.voice_reg = re.compile(r"\{(\w+)\}")
+class F5TTSCreate:
+    voice_reg = re.compile(r"\{(\w+)\}")

-    @staticmethod
-    def get_txt_file_path(file):
-        p = Path(file)
-        return os.path.join(os.path.dirname(file), p.stem + ".txt")
+    def is_voice_name(self, word):
+        return self.voice_reg.match(word.strip())

-    @classmethod
-    def INPUT_TYPES(s):
-        input_dir = folder_paths.get_input_directory()
-        files = folder_paths.filter_files_content_types(
-            os.listdir(input_dir), ["audio", "video"]
-        )
-        filesWithTxt = []
-        for file in files:
-            txtFile = F5TTSAudio.get_txt_file_path(file)
-            if os.path.isfile(os.path.join(input_dir, txtFile)):
-                filesWithTxt.append(file)
-        return {
-            "required": {
-                "sample": (sorted(filesWithTxt), {"audio_upload": True}),
-                "speech": ("STRING", {
-                    "multiline": True,
-                    "default": "Hello World"
-                }),
-            }
-        }
+    def get_voice_names(self, chunks):
+        voice_names = {}
+        for text in chunks:
+            match = self.is_voice_name(text)
+            if match:
+                voice_names[match[1]] = True
+        return voice_names

-    CATEGORY = "audio"
+    def split_text(self, speech):
+        reg1 = r"(?=\{\w+\})"
+        return re.split(reg1, speech)

-    RETURN_TYPES = ("AUDIO", )
-    FUNCTION = "create"
+    @staticmethod
+    def load_voice(ref_audio, ref_text):
+        main_voice = {"ref_audio": ref_audio, "ref_text": ref_text}
+
+        main_voice["ref_audio"], main_voice["ref_text"] = preprocess_ref_audio_text(  # noqa E501
+            ref_audio, ref_text
+        )
+        return main_voice

-    def create_with_cli(self, audio_path, audio_text, speech, output_dir):
-        subprocess.run(
-            [
-                "python", "inference-cli.py", "--model", "F5-TTS",
-                "--ref_audio", audio_path, "--ref_text", audio_text,
-                "--gen_text", speech,
-                "--output_dir", output_dir
-            ],
-            cwd=Install.f5TTSPath
-        )
-        output_audio = os.path.join(output_dir, "out.wav")
-        with wave.open(output_audio, "rb") as wave_file:
-            frame_rate = wave_file.getframerate()
-
-        waveform, sample_rate = torchaudio.load(output_audio)
-        audio = {"waveform": waveform.unsqueeze(0), "sample_rate": frame_rate}
-        return audio

def load_model(self):
model_cls = DiT
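The core of the refactor above: the multi-voice text handling now lives on the new F5TTSCreate class so both nodes can share it. A minimal sketch of what split_text and get_voice_names do, run on a made-up input string (the sample text is illustrative, not from the commit):

```python
import re

voice_reg = re.compile(r"\{(\w+)\}")   # same pattern as F5TTSCreate.voice_reg

def split_text(speech):
    # zero-width lookahead: split *before* each {name}, keeping the marker
    return re.split(r"(?=\{\w+\})", speech)

chunks = split_text("Hello there. {bob} Hi! {main} Welcome back.")
# -> ['Hello there. ', '{bob} Hi! ', '{main} Welcome back.']

voice_names = {}
for text in chunks:
    match = voice_reg.match(text.strip())   # what is_voice_name() checks
    if match:
        voice_names[match[1]] = True
# -> {'bob': True, 'main': True}
```

Chunks that name no loaded voice fall back to the main voice, per the `voice not in voices` check inside generate_audio below.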
@@ -95,29 +68,6 @@ def load_model(self):
ema_model = load_model(model_cls, model_cfg, ckpt_file, vocab_file)
return ema_model

-    def load_voice(self, ref_audio, ref_text):
-        main_voice = {"ref_audio": ref_audio, "ref_text": ref_text}
-
-        main_voice["ref_audio"], main_voice["ref_text"] = preprocess_ref_audio_text(  # noqa E501
-            ref_audio, ref_text
-        )
-        return main_voice
-
-    def is_voice_name(self, word):
-        return self.voice_reg.match(word.strip())
-
-    def get_voice_names(self, chunks):
-        voice_names = {}
-        for text in chunks:
-            match = self.is_voice_name(text)
-            if match:
-                voice_names[match[1]] = True
-        return voice_names
-
-    def split_text(self, speech):
-        reg1 = r"(?=\{\w+\})"
-        return re.split(reg1, speech)

def generate_audio(self, voices, model_obj, chunks):
frame_rate = 44100
generated_audio_segments = []
@@ -133,7 +83,7 @@ def generate_audio(self, voices, model_obj, chunks):
if voice not in voices:
print(f"Voice {voice} not found, using main.")
voice = "main"
-            text = self.voice_reg.sub("", text)
+            text = F5TTSCreate.voice_reg.sub("", text)
gen_text = text.strip()
ref_audio = voices[voice]["ref_audio"]
ref_text = voices[voice]["ref_text"]
@@ -160,6 +110,137 @@ def generate_audio(self, voices, model_obj, chunks):
os.unlink(wave_file.name)
return audio

+    def create(self, voices, chunks):
+        model_obj = self.load_model()
+        return self.generate_audio(voices, model_obj, chunks)


+class F5TTSAudioInputs:
+    def __init__(self):
+        self.wave_file = None
+
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "sample_audio": ("AUDIO",),
+                "sample_text": ("STRING", {"default": "Text of sample_audio"}),
+                "speech": ("STRING", {
+                    "multiline": True,
+                    "default": "This is what I want to say"
+                }),
+            },
+        }
+
+    CATEGORY = "audio"
+
+    RETURN_TYPES = ("AUDIO", )
+    FUNCTION = "create"
+
+    def load_voice_from_input(self, sample_audio, sample_text):
+        self.wave_file = tempfile.NamedTemporaryFile(
+            suffix=".wav", delete=False
+        )
+        for (batch_number, waveform) in enumerate(
+                sample_audio["waveform"].cpu()):
+            buff = io.BytesIO()
+            torchaudio.save(
+                buff, waveform, sample_audio["sample_rate"], format="WAV"
+            )
+            with open(self.wave_file.name, 'wb') as f:
+                f.write(buff.getbuffer())
+            break
+        r = F5TTSCreate.load_voice(self.wave_file.name, sample_text)
+        return r
+
+    def remove_wave_file(self):
+        if self.wave_file is not None:
+            try:
+                os.unlink(self.wave_file.name)
+                self.wave_file = None
+            except Exception as e:
+                print("F5TTS: Cannot remove? " + self.wave_file.name)
+                print(e)
+
+    def create(self, sample_audio, sample_text, speech):
+        try:
+            main_voice = self.load_voice_from_input(sample_audio, sample_text)
+
+            f5ttsCreate = F5TTSCreate()
+
+            voices = {}
+            chunks = f5ttsCreate.split_text(speech)
+            voices['main'] = main_voice
+
+            audio = f5ttsCreate.create(voices, chunks)
+        finally:
+            self.remove_wave_file()
+        return (audio, )

+    @classmethod
+    def IS_CHANGED(s, sample_audio, sample_text, speech):
+        m = hashlib.sha256()
+        # hashlib.update() needs bytes: hash the raw waveform data
+        # and encode the strings
+        m.update(sample_text.encode("utf-8"))
+        m.update(sample_audio["waveform"].cpu().numpy().tobytes())
+        m.update(speech.encode("utf-8"))
+        return m.digest().hex()


+class F5TTSAudio:
+    def __init__(self):
+        self.use_cli = False
+
+    @staticmethod
+    def get_txt_file_path(file):
+        p = Path(file)
+        return os.path.join(os.path.dirname(file), p.stem + ".txt")
+
+    @classmethod
+    def INPUT_TYPES(s):
+        input_dir = folder_paths.get_input_directory()
+        files = folder_paths.filter_files_content_types(
+            os.listdir(input_dir), ["audio", "video"]
+        )
+        filesWithTxt = []
+        for file in files:
+            txtFile = F5TTSAudio.get_txt_file_path(file)
+            if os.path.isfile(os.path.join(input_dir, txtFile)):
+                filesWithTxt.append(file)
+        filesWithTxt = sorted(filesWithTxt)
+
+        return {
+            "required": {
+                "sample": (filesWithTxt, {"audio_upload": True}),
+                "speech": ("STRING", {
+                    "multiline": True,
+                    "default": "This is what I want to say"
+                }),
+            }
+        }
+
+    CATEGORY = "audio"
+
+    RETURN_TYPES = ("AUDIO", )
+    FUNCTION = "create"
+
+    def create_with_cli(self, audio_path, audio_text, speech, output_dir):
+        subprocess.run(
+            [
+                "python", "inference-cli.py", "--model", "F5-TTS",
+                "--ref_audio", audio_path, "--ref_text", audio_text,
+                "--gen_text", speech,
+                "--output_dir", output_dir
+            ],
+            cwd=Install.f5TTSPath
+        )
+        output_audio = os.path.join(output_dir, "out.wav")
+        with wave.open(output_audio, "rb") as wave_file:
+            frame_rate = wave_file.getframerate()
+
+        waveform, sample_rate = torchaudio.load(output_audio)
+        audio = {"waveform": waveform.unsqueeze(0), "sample_rate": frame_rate}
+        return audio

def load_voice_from_file(self, sample):
input_dir = folder_paths.get_input_directory()
txt_file = os.path.join(
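This hunk is the commit's headline feature: F5TTSAudioInputs takes ComfyUI's AUDIO type plus a transcript string instead of a file picker. An AUDIO value is a dict holding a [batch, channels, samples] waveform tensor and a sample rate; load_voice_from_input writes the first batch entry to a temporary wav because preprocess_ref_audio_text expects a file path. A rough sketch of driving the node directly, outside a ComfyUI graph (`reference.wav` and both strings are placeholders):

```python
import torchaudio

waveform, sample_rate = torchaudio.load("reference.wav")
sample_audio = {
    "waveform": waveform.unsqueeze(0),   # add the batch dimension
    "sample_rate": sample_rate,
}

node = F5TTSAudioInputs()
(audio,) = node.create(sample_audio, "Transcript of reference.wav",
                       "This is what I want to say")
```

The try/finally in create() guarantees the temporary wav is removed even if generation fails.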
@@ -170,7 +251,7 @@ def load_voice_from_file(self, sample):
with open(txt_file, 'r') as file:
audio_text = file.read()
audio_path = folder_paths.get_annotated_filepath(sample)
-        return self.load_voice(audio_path, audio_text)
+        return F5TTSCreate.load_voice(audio_path, audio_text)

def load_voices_from_files(self, sample, voice_names):
voices = {}
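The rest of load_voices_from_files is elided here, but judging by get_txt_file_path and load_voice_from_file, each voice is a wav/txt pair sharing a stem in ComfyUI's input directory, presumably resolved from the `{name}` markers found in the speech text:

```python
# get_txt_file_path derives the transcript path from the audio path:
F5TTSAudio.get_txt_file_path("/path/to/bob.wav")   # -> '/path/to/bob.txt'
```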
@@ -194,6 +275,7 @@ def create(self, sample, speech):
# Install.check_install()
main_voice = self.load_voice_from_file(sample)

+        f5ttsCreate = F5TTSCreate()
if self.use_cli:
# working...
output_dir = tempfile.mkdtemp()
@@ -204,21 +286,23 @@
)
shutil.rmtree(output_dir)
else:
-            model_obj = self.load_model()
-            chunks = self.split_text(speech)
-            voice_names = self.get_voice_names(chunks)
+            chunks = f5ttsCreate.split_text(speech)
+            voice_names = f5ttsCreate.get_voice_names(chunks)
voices = self.load_voices_from_files(sample, voice_names)
voices['main'] = main_voice

-            audio = self.generate_audio(voices, model_obj, chunks)
+            audio = f5ttsCreate.create(voices, chunks)
return (audio, )

    @classmethod
    def IS_CHANGED(s, sample, speech):
        m = hashlib.sha256()
        audio_path = folder_paths.get_annotated_filepath(sample)
+        audio_txt_path = F5TTSAudio.get_txt_file_path(audio_path)
+        last_modified_timestamp = os.path.getmtime(audio_path)
+        txt_last_modified_timestamp = os.path.getmtime(audio_txt_path)
        m.update(audio_path.encode("utf-8"))
+        m.update(str(last_modified_timestamp).encode("utf-8"))
+        m.update(str(txt_last_modified_timestamp).encode("utf-8"))
        m.update(speech.encode("utf-8"))
        return m.digest().hex()
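After this commit both nodes reduce to the same three F5TTSCreate calls. A condensed sketch of the shared flow (paths and strings are illustrative):

```python
f5tts = F5TTSCreate()

# reference voice -> preprocessed (audio, transcript) pair
main_voice = F5TTSCreate.load_voice("voice.wav", "Transcript of voice.wav")

chunks = f5tts.split_text("Hello. {bob} Hi there.")
voices = {"main": main_voice}
# the file-based node also fills voices["bob"] via load_voices_from_files

audio = f5tts.create(voices, chunks)   # load_model, then generate_audio
```

Any `{name}` without a loaded voice falls back to `main`, per the check at the top of generate_audio.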
README.md: 5 changes (3 additions, 2 deletions)
@@ -10,8 +10,9 @@ Using F5-TTS https://github.com/SWivid/F5-TTS
* Press refresh to see it in the node

You can use the examples here...
-* [examples voices](examples/)
-* [simple workflow](examples/simple_ComfyUI_F5TTS_workflow.json)
+* [Examples voices](examples/)
+* [Simple workflow](examples/simple_ComfyUI_F5TTS_workflow.json)
+* [Workflow with input audio only, using OpenAI's Whisper to get the text](examples/F5TTS_whisper_workflow.json)


### Multi voices...
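The README section is cut off here, but the marker syntax parsed by `voice_reg` in F5TTS.py looks like this (sample lines are illustrative, not from the repo):

```
{main} Hello, this is the main voice.
{bob} This part is generated with the voice named bob.
```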
__init__.py: 8 changes (5 additions, 3 deletions)
@@ -1,9 +1,11 @@

-from .F5TTS import F5TTSAudio
+from .F5TTS import F5TTSAudio, F5TTSAudioInputs

NODE_CLASS_MAPPINGS = {
-    "F5TTSAudio": F5TTSAudio
+    "F5TTSAudio": F5TTSAudio,
+    "F5TTSAudioInputs": F5TTSAudioInputs
}
NODE_DISPLAY_NAME_MAPPINGS = {
-    "F5TTSAudio": "F5-TTS Audio"
+    "F5TTSAudio": "F5-TTS Audio",
+    "F5TTSAudioInputs": "F5-TTS Audio from inputs"
}
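These mappings are how ComfyUI discovers the package's nodes: with this change the new class is registered alongside the old one, and F5TTSAudioInputs appears in the node menu as "F5-TTS Audio from inputs", next to the existing file-based "F5-TTS Audio" node.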