forked from Plachtaa/VALL-E-X
-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathtest.py
61 lines (55 loc) · 2.24 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import os
import torch
import logging
from data.dataset import create_dataloader
from macros import *
from data.tokenizer import (
AudioTokenizer,
tokenize_audio,
)
from data.collation import get_text_token_collater
from models.vallex import VALLE
# Module-level default device: the first CUDA GPU when one is available,
# otherwise the CPU. Without the CPU fallback `device` would be left undefined
# on CPU-only machines, and any later use of it would raise NameError.
if torch.cuda.is_available():
    device = torch.device("cuda", 0)
else:
    device = torch.device("cpu")
from vocos import Vocos
def get_model(device):
    """Build the VALL-E X model, the audio codec and the Vocos vocoder.

    Weights are loaded from ``./checkpoints/vallex-checkpoint_modified.pt``.
    When that file is missing, the upstream checkpoint is downloaded to
    ``./checkpoints/vallex-checkpoint.pt``. The download is skipped if that
    file is already present, which avoids piling up duplicate downloads.

    Args:
        device: target passed to ``.to()`` for the model and vocoder, and to
            ``AudioTokenizer``.

    Returns:
        Tuple ``(model, codec, vocos)``:
            - the ``VALLE`` model with weights loaded;
            - an ``AudioTokenizer``;
            - the ``Vocos`` vocoder.

    Raises:
        RuntimeError: the upstream checkpoint download failed, or the loaded
            state dict does not match the model.
        FileNotFoundError: the modified checkpoint file does not exist.
    """
    url = "https://huggingface.co/Plachta/VALL-E-X/resolve/main/vallex-checkpoint.pt"
    checkpoints_dir = "./checkpoints"
    base_checkpoint_name = "vallex-checkpoint.pt"
    model_checkpoint_name = "vallex-checkpoint_modified.pt"
    base_checkpoint_path = os.path.join(checkpoints_dir, base_checkpoint_name)
    model_checkpoint_path = os.path.join(checkpoints_dir, model_checkpoint_name)

    os.makedirs(checkpoints_dir, exist_ok=True)

    if not os.path.exists(model_checkpoint_path):
        if not os.path.exists(base_checkpoint_path):
            # Third-party module, only needed when a download is required.
            import wget

            logging.info("Downloading model from %s ...", url)
            try:
                wget.download(url, out=base_checkpoint_path, bar=wget.bar_adaptive)
            except Exception as e:
                logging.error("Model weights download failed: %s", e)
                raise RuntimeError(
                    "\n Model weights download failed, please go to '{}'"
                    "\n manually download model weights and put it to {} .".format(
                        url, os.path.abspath(checkpoints_dir)
                    )
                ) from e
        # The download above produces the *upstream* checkpoint, not the
        # modified one that is loaded below. Fail with an explicit message
        # instead of an opaque torch.load error.
        # NOTE(review): confirm how the modified checkpoint is meant to be
        # produced from the upstream one.
        if not os.path.exists(model_checkpoint_path):
            raise FileNotFoundError(
                "Model checkpoint '{}' not found; the downloaded upstream weights "
                "are at '{}'.".format(
                    os.path.abspath(model_checkpoint_path),
                    os.path.abspath(base_checkpoint_path),
                )
            )

    # VALL-E
    model = VALLE(
        N_DIM,
        NUM_HEAD,
        NUM_LAYERS,
        norm_first=True,
        add_prenet=False,
        prefix_mode=PREFIX_MODE,
        share_embedding=True,
        nar_scale_factor=1.0,
        prepend_bos=True,
        num_quantizers=NUM_QUANTIZERS,
    ).to(device)

    # Load on CPU first; the model was already moved to `device` above.
    # NOTE(review): torch.load unpickles the file, so only load trusted
    # checkpoints.
    checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
    missing_keys, _unexpected_keys = model.load_state_dict(
        checkpoint["model"], strict=True
    )
    # Explicit check instead of `assert`, so it still runs under `python -O`.
    if missing_keys:
        raise RuntimeError(
            "Checkpoint is missing model keys: {}".format(missing_keys)
        )

    # Encodec
    codec = AudioTokenizer(device)
    vocos = Vocos.from_pretrained("charactr/vocos-encodec-24khz").to(device)
    return model, codec, vocos