mytrain.py
import os
import pickle

import mne
import numpy as np
import torch
import tqdm
from torch.utils.data import TensorDataset, DataLoader

from dn3.transforms.batch import RandomTemporalCrop
from dn3_ext import BendingCollegeWav2Vec, ConvEncoderBENDR, BENDRContextualizer

mne.set_log_level(False)
# Dataset: MindMNIST.pkl stores three objects pickled in sequence:
# the raw trials, the z-scored trials, and a digit label per trial.
with open('MindMNIST.pkl', 'rb') as f:
    arrays = pickle.load(f)        # raw trials (not used below)
    z_arrays = pickle.load(f)      # z-scored trials
    digit_labels = pickle.load(f)  # one digit label per trial
digit_labels = np.array(digit_labels)
# Use the z-scored trials, stacked into a single array (presumably (n_trials, channels, samples)).
arrays = np.stack(z_arrays, axis=0)
# Binary task: keep only the trials for digits 1 and 5.
valid_data = (digit_labels == 1) | (digit_labels == 5)
X = arrays[valid_data, :]
y = digit_labels[valid_data]
# Map the labels {1, 5} to {0, 1}: 1 // 5 == 0 and 5 // 5 == 1.
y = y // np.max(y)
# Repeat each trial 10x along the last (time) axis, presumably so the sequences
# are long enough for BENDR's masking/contrastive pretraining.
X = np.tile(X, 10)
# Convert to torch tensors
X_train = torch.tensor(X, dtype=torch.float32)
y_train = torch.tensor(y, dtype=torch.int64)
# Wrap in a TensorDataset; process.fit() below consumes the dataset directly.
training_dataset = TensorDataset(X_train, y_train)
# train_loader = DataLoader(training_dataset, batch_size=32, shuffle=True)
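# Optional sanity check (assumption: each trial is laid out as (channels, samples),
# so the channel count must match the encoder's in_features below).
print("X_train:", tuple(X_train.shape), "y_train:", tuple(y_train.shape))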
# Model: convolutional encoder + transformer contextualizer, pretrained with the
# wav2vec-style self-supervised BENDR objective.
encoder = ConvEncoderBENDR(4, encoder_h=512)  # 4 input channels
contextualizer = BENDRContextualizer(encoder.encoder_h)
process = BendingCollegeWav2Vec(encoder, contextualizer)
# Single Adam optimizer (default hyperparameters) over all parameters.
process.set_optimizer(torch.optim.Adam(process.parameters()))
# Randomly crop each batch in time as augmentation.
process.add_batch_transform(RandomTemporalCrop())
process.fit(training_dataset, epochs=1, num_workers=0)
# print(process.evaluate(training_dataset))
tqdm.tqdm.write("Saving last model...")
os.makedirs('checkpoints', exist_ok=True)  # make sure the output directory exists
encoder.save('checkpoints/my_encoder.pt')
contextualizer.save('checkpoints/my_contextualizer.pt')
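# Sketch of reloading the checkpoints later (e.g. for fine-tuning). Assumption:
# the dn3_ext encoder/contextualizer expose a load(filename) counterpart to the
# save(filename) calls above, and the constructor arguments must match pretraining.
# encoder = ConvEncoderBENDR(4, encoder_h=512)
# encoder.load('checkpoints/my_encoder.pt')
# contextualizer = BENDRContextualizer(encoder.encoder_h)
# contextualizer.load('checkpoints/my_contextualizer.pt')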