-
Notifications
You must be signed in to change notification settings - Fork 5
/
test_aves.py
46 lines (33 loc) · 1.2 KB
/
test_aves.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import fairseq
import torch
import torch.nn as nn
class AvesClassifier(nn.Module):
    """Linear classification head on top of a pretrained AVES (fairseq) encoder.

    The encoder checkpoint is loaded with fairseq, its ``feature_extractor``
    submodule is frozen, the features returned by ``extract_features`` are
    mean-pooled over dim 1, and a single ``nn.Linear`` layer maps the pooled
    embedding to class logits.

    Args:
        model_path: Path to the fairseq checkpoint file.
        num_classes: Number of output logits.
        embeddings_dim: Width of the encoder's output features; must match the
            checkpoint (768 is the default here -- confirm for other checkpoints).
        multi_label: If True, use ``BCEWithLogitsLoss`` (multi-hot targets);
            otherwise use ``CrossEntropyLoss`` (one class index per example).
    """

    def __init__(self, model_path, num_classes, embeddings_dim=768, multi_label=False):
        super().__init__()
        # fairseq returns (models, cfg, task); only the first model is used.
        models, _cfg, _task = fairseq.checkpoint_utils.load_model_ensemble_and_task([model_path])
        self.model = models[0]
        # Freeze only the encoder's feature_extractor submodule (presumably the
        # raw-waveform front end -- confirm); other encoder params are untouched.
        self.model.feature_extractor.requires_grad_(False)
        self.head = nn.Linear(in_features=embeddings_dim, out_features=num_classes)
        if multi_label:
            self.loss_func = nn.BCEWithLogitsLoss()
        else:
            self.loss_func = nn.CrossEntropyLoss()

    def forward(self, x, y=None):
        """Compute class logits and, when targets are given, the loss.

        Args:
            x: Batch of raw waveforms, presumably shaped (batch, samples) --
                TODO confirm against the encoder's expectations.
            y: Optional targets. For single-label mode, class indices of shape
                (batch,). For multi-label mode, multi-hot targets of shape
                (batch, num_classes); integer targets are cast to the logits'
                floating-point dtype.

        Returns:
            A ``(loss, logits)`` tuple. ``loss`` is None when ``y`` is None;
            ``logits`` has ``num_classes`` entries in its last dimension.
        """
        # Element 0 of extract_features' result holds the frame features,
        # assumed (batch, frames, embeddings_dim); average over frames.
        features = self.model.extract_features(x)[0]
        pooled = features.mean(dim=1)
        logits = self.head(pooled)

        loss = None
        if y is not None:
            if isinstance(self.loss_func, nn.BCEWithLogitsLoss):
                # BCEWithLogitsLoss needs floating-point targets; integer
                # multi-hot labels would otherwise fail with a dtype error.
                y = y.to(logits.dtype)
            loss = self.loss_func(logits, y)
        return loss, logits
# Smoke test: build an AVES classifier with 10 target classes from the
# checkpoint at ./aves-base-bio.pt and run one forward pass on random audio.
model = AvesClassifier(
    model_path='./aves-base-bio.pt',
    num_classes=10)
model.eval()

# 16,000 random samples. The original comment calls this 1 second of sound,
# i.e. it assumes a 16 kHz input rate -- confirm against the checkpoint.
waveform = torch.rand(16_000)
x = waveform.unsqueeze(0)  # add a leading batch dimension -> (1, 16000)
y = torch.tensor([0])      # one target class index for the single example

# Run the forward pass and check that it produced usable outputs.
loss, logits = model(x, y)
assert logits.shape == (1, 10), logits.shape
assert loss is not None and torch.isfinite(loss), loss