Commit 9dadcd6

minor changes

NumericalMax committed Jan 23, 2024
1 parent 832d4fd commit 9dadcd6
Showing 7 changed files with 247 additions and 274 deletions.
409 changes: 197 additions & 212 deletions analysis/article.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion src/evaluate/personalization.py
@@ -23,7 +23,7 @@ def fine_tune_evaluate(self, dataset, batch_size=256):
         #model_fine_tune = tf.keras.models.clone_model(self._model)
         #model_fine_tune.compile(optimizer=RMSprop(learning_rate=0.001))
         self._model.fit(
-            Helper.data_generator(train),
+            Helper.data_generator([train]),
             steps_per_epoch=len(train),
             epochs=epochs,
             callbacks=CSVLogger(path + '/training_progress.csv'),
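The wrapped argument follows the new Helper.data_generator signature, which now expects a list of batched datasets (see src/utils/helper.py below). A minimal, hypothetical sketch of the calling convention:

# Hypothetical sketch: the generator API now takes a list, so a single
# fine-tuning dataset is wrapped; with the default method='continue' it
# cycles indefinitely and steps_per_epoch=len(train) bounds each epoch.
gen = Helper.data_generator([train])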
15 changes: 8 additions & 7 deletions src/main.py
@@ -34,19 +34,20 @@ def main(parameters):
     ######################################################
     # DATA LOADING
     ######################################################
-    train, size_train = Helper.load_dataset(parameters['train_dataset'])
-    val, _ = Helper.load_dataset(parameters['val_dataset'])
+    #train, size_train = Helper.load_dataset(parameters['train_dataset'])
+    train, size_train = Helper.load_multiple_datasets(parameters['train_dataset'])
+    val, size_val = Helper.load_multiple_datasets(parameters['val_dataset'])

     ######################################################
     # MACHINE LEARNING
     ######################################################
     callbacks = [
-        ReduceLROnPlateau(monitor='recon', factor=0.05, patience=50, min_lr=0.000001),
+        ReduceLROnPlateau(monitor='recon', factor=0.05, patience=20, min_lr=0.000001),
         TerminateOnNaN(),
         CSVLogger(base_path + 'training/training_progress.csv'),
         CoefficientScheduler(parameters['epochs'], parameters['coefficients']),
-        #ModelCheckpoint(filepath=base_path + 'model/', monitor='loss', save_best_only=True, verbose=0),
-        ReconstructionPlot(train, parameters['index_tracked_sample'], base_path + 'training/reconstruction/',
+        ModelCheckpoint(filepath=base_path + 'model/', monitor='loss', save_best_only=True, verbose=0),
+        ReconstructionPlot(train[0], parameters['index_tracked_sample'], base_path + 'training/reconstruction/',
                            period=parameters['period_reconstruction_plot']),
         #CollapseCallback(val),
         EarlyStopping(monitor="val_loss", patience=parameters['early_stopping'])
@@ -58,8 +59,8 @@ def main(parameters):
     vae = TCVAE(encoder, decoder, parameters['coefficients'], size_train)
     vae.compile(optimizer=RMSprop(learning_rate=parameters['learning_rate']))
     vae.fit(
-        Helper.data_generator(train), steps_per_epoch=len(train),
-        validation_data=Helper.data_generator(val), validation_steps=len(val),
+        Helper.data_generator(train), steps_per_epoch=size_train,
+        validation_data=Helper.data_generator(val), validation_steps=size_val,
         epochs=parameters['epochs'], callbacks=callbacks, verbose=1,
     )
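For orientation, a condensed sketch of how the new multi-dataset loading feeds training; Helper.load_multiple_datasets and Helper.data_generator are defined in src/utils/helper.py below, while the import path and the trimmed parameter dictionary are assumptions for illustration only:

from utils.helper import Helper   # import path assumed relative to src/

train_cfg = {'name': ['synth', 'zheng'], 'split': 'train',
             'shuffle_size': 1024, 'batch_size': 1024}

# Returns a list of batched tf.data.Datasets plus the combined number of batches,
# which is why steps_per_epoch switches from len(train) to size_train above.
train, size_train = Helper.load_multiple_datasets(train_cfg)
gen = Helper.data_generator(train)               # round-robin over the dataset list
# vae.fit(gen, steps_per_epoch=size_train, ...)  # as wired up in main()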
21 changes: 4 additions & 17 deletions src/model/tcvae.py
@@ -34,45 +34,32 @@ def gamma(self):
     def gamma(self, value):
         self._gamma.assign(value)


     def _loss(self, reconstruction, x, mu, log_var, z):
         size_batch = tf.shape(x)[0]
-        logiw_mat = self.log_importance_weight_matrix_iso(size_batch)
+        logiw_mat = self.log_importance_weight_matrix(size_batch)

         recon_loss = tf.reduce_sum(tf.keras.losses.mean_absolute_error(x, reconstruction))
         # log(q(z|x))
         log_qz_x = tf.reduce_sum(self._dist(mu, tf.exp(log_var)).log_prob(z), axis=-1)
-        #log_qz_x = tf.reduce_sum(self.log_normal_pdf(z, mu, log_var), axis=-1)

         # log(p(z))
         log_prior = tf.reduce_sum(self._dist(tf.zeros_like(z), tf.ones_like(z)).log_prob(z), axis=-1)
-        #log_prior = tf.reduce_sum(
-        #    self.log_normal_pdf(
-        #        z,
-        #        tf.zeros_like(z),
-        #        tf.zeros_like(z),
-        #    ), axis=-1)

         # log(q(z(x_j) | x_i))
         log_qz_prob = self._dist(
             tf.expand_dims(mu, 0), tf.expand_dims(tf.exp(log_var), 0),
         ).log_prob(tf.expand_dims(z, 1))
-        #log_qz_prob = self.log_normal_pdf(
-        #    tf.expand_dims(z, 1),
-        #    tf.expand_dims(mu, 0),
-        #    tf.expand_dims(log_var, 0),
-        #)

-        # Weighting as we are not calculating q based on the complete dataset but only a batch
-        #log_qz_prob = log_qz_prob + tf.expand_dims(logiw_mat, 2)

         # log(q(z))
-        # we can simply sum due to the assumption of independent factors
+        # we can sum due to the assumption of independent factors
         log_qz = tf.reduce_logsumexp(logiw_mat + tf.reduce_sum(log_qz_prob, axis=2), axis=1)
         # log(PI_i q(z_i))
         log_qz_product = tf.reduce_sum(tf.reduce_logsumexp(tf.expand_dims(logiw_mat, 2) + log_qz_prob, axis=1), axis=1)

         # I[z;x] = KL[q(z,x)||q(x)q(z)] = E_x[KL[q(z|x)||q(z)]]
         mutual_info_loss = tf.reduce_mean(tf.subtract(log_qz_x, log_qz))

         # TC[z] = KL[q(z)||\prod_i z_i]
         tc_loss = tf.reduce_mean(tf.subtract(log_qz, log_qz_product))
         # dw_kl_loss is KL[q(z)||p(z)] instead of usual KL[q(z|x)||p(z))]
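The comments in _loss follow the usual β-TC-VAE decomposition (Chen et al., 2018), in which the aggregated KL term splits into mutual information, total correlation, and dimension-wise KL. Written out as the standard identity (stated here for reference, not taken from this repository):

\mathbb{E}_{p(x)}\big[\mathrm{KL}\big(q(z \mid x)\,\|\,p(z)\big)\big]
  = \underbrace{\mathrm{KL}\big(q(z,x)\,\|\,q(z)\,p(x)\big)}_{\text{mutual information } I[z;x]}
  + \underbrace{\mathrm{KL}\Big(q(z)\,\Big\|\,\textstyle\prod_i q(z_i)\Big)}_{\text{total correlation } TC[z]}
  + \underbrace{\textstyle\sum_i \mathrm{KL}\big(q(z_i)\,\|\,p(z_i)\big)}_{\text{dimension-wise KL}}

Both log q(z) and log ∏_i q(z_i) are estimated from a minibatch, which is where the importance weight matrix comes in. The repository's log_importance_weight_matrix is not shown in this diff; below is a hedged NumPy sketch of the published minibatch-weighted-sampling construction (the class method above takes only the batch size and presumably reads the dataset size from the instance):

import numpy as np

def log_importance_weight_matrix(batch_size, dataset_size):
    # Minibatch weighted sampling weights from the beta-TC-VAE paper; sketch only.
    N, M = dataset_size, batch_size - 1
    strat_weight = (N - M) / (N * M)
    W = np.full((batch_size, batch_size), 1.0 / M)
    W.flat[::M + 1] = 1.0 / N          # diagonal: a sample paired with itself
    W.flat[1::M + 1] = strat_weight    # one stratified entry per row
    W[M - 1, 0] = strat_weight
    return np.log(W)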
17 changes: 11 additions & 6 deletions src/params.yml
@@ -1,22 +1,27 @@
 train_dataset:
-  name: 'zheng'
+  name:
+    - synth
+    - zheng
+    - ptb
+    - medalcare
   split: 'train'
   shuffle_size: 1024
   batch_size: 1024
 val_dataset:
-  name: 'zheng'
+  name:
+    - zheng
   split: 'train'
   shuffle_size: 1024
   batch_size: 1024
 save_results_path: ../results/
 seed: 42
-epochs: 50
-latent_dimension: 3
+epochs: 200
+latent_dimension: 8
 learning_rate: 0.001
 coefficients:
   alpha: 1.0
-  beta: 3.0
-  gamma: 3.0
+  beta: 64.0
+  gamma: 1.0
 coefficients_raise: 20
 early_stopping: 50000
 period_reconstruction_plot: 20
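The config is plain YAML and presumably ends up as the parameters dictionary used in src/main.py; a minimal loading sketch with PyYAML (the repository's actual loading code is not shown in this diff):

import yaml  # PyYAML

with open('src/params.yml') as f:
    parameters = yaml.safe_load(f)

parameters['train_dataset']['name']   # ['synth', 'zheng', 'ptb', 'medalcare']
parameters['coefficients']['beta']    # 64.0 (presumably the total-correlation weight)

Given the loss terms in tcvae.py, alpha, beta and gamma most plausibly weight the mutual-information, total-correlation and dimension-wise KL terms respectively, so raising beta from 3.0 to 64.0 pushes much harder for disentanglement.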
34 changes: 6 additions & 28 deletions src/utils/callbacks.py
@@ -53,37 +53,15 @@ def __init__(self, data, aggressive=True):
super().__init__()
self._aggressive = aggressive
self._data = data
self._temp = -np.inf
iterator = iter(data)
batch = next(iterator)
self._last_value = -np.inf
batch = next(iter(data))
self._ecg = batch['ecg']['I']

def on_epoch_begin(self, epoch, logs=None):
res = self.model.compute_information_gain(self._ecg)
self.aggressive = res >= self._temp
if self.aggressive:
if self.model._decoder.trainable:
self.model._decoder.trainable = False
else:
self.model._encoder.trainable = True
self.model._decoder.trainable = False
self._temp = res
aggr = res >= self._last_value
if aggr:
self.model._decoder.trainable = False
else:
self.model._encoder.trainable = True
self.model._decoder.trainable = True

def on_epoch_end(self, epoch, logs=None):

res = self.model.compute_information_gain(self._ecg)
tf.print(res)
self._aggressive = res >= self._temp
if self.aggressive:
if self.model._decoder.trainable:
self.model._decoder.trainable = False
else:
self.model._encoder.trainable = True
self.model._decoder.trainable = False
self.temp = res
else:
self.model._encoder.trainable = True
self.model._decoder.trainable = True
self._last_value = res
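Because the rendered diff does not mark which of the interleaved lines were removed and which were added, here is one plausible reading of the simplified callback as a hedged, condensed sketch; the class name, the compute_information_gain call and the _encoder/_decoder attributes are taken from the lines above, everything else is reconstruction (and the callback remains commented out in the callbacks list in src/main.py above):

import numpy as np
import tensorflow as tf

class CollapseCallback(tf.keras.callbacks.Callback):
    """Freeze the decoder while the information gain keeps increasing
    (aggressive phase); otherwise train encoder and decoder jointly."""

    def __init__(self, data):
        super().__init__()
        self._last_value = -np.inf
        batch = next(iter(data))
        self._ecg = batch['ecg']['I']

    def on_epoch_begin(self, epoch, logs=None):
        res = self.model.compute_information_gain(self._ecg)
        if res >= self._last_value:              # still improving
            self.model._decoder.trainable = False
        else:                                    # plateaued
            self.model._encoder.trainable = True
            self.model._decoder.trainable = True

    def on_epoch_end(self, epoch, logs=None):
        self._last_value = self.model.compute_information_gain(self._ecg)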
23 changes: 20 additions & 3 deletions src/utils/helper.py
@@ -26,14 +26,20 @@ def generate_paths(paths):

     @staticmethod
     def data_generator(dataset, method='continue'):
-        iterator = iter(dataset)
+        k = 0
+        n = len(dataset)
+        iterator = iter(dataset[k])
         while True:
             try:
                 batch = next(iterator)
                 yield batch['ecg']['I']
             except StopIteration:
                 if method == 'continue':
-                    iterator = iter(dataset)
+                    if k == n - 1:
+                        k = 0
+                    else:
+                        k = k + 1
+                    iterator = iter(dataset[k])
                 elif method == 'stop':
                     return

@@ -83,7 +89,7 @@ def get_embedding(model, dataset, split='train', save_path=None, batch_size=512)
         data_train = tfds.load(dataset, split=[split])
         train = data_train[0].batch(batch_size).prefetch(tf.data.AUTOTUNE)
         labels = Helper.get_labels(train)
-        z_mean, z_log_var = model._encoder.predict(Helper.data_generator(train, method='stop'))
+        z_mean, z_log_var = model._encoder.predict(Helper.data_generator([train], method='stop'))
         z = model.reparameterize(z_mean, z_log_var)

         z_mean = np.expand_dims(z_mean, axis=2)
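A hedged usage sketch of the reworked generator: it expects a list of batched datasets and cycles through them round-robin when method='continue'; with method='stop' it returns as soon as the current iterator is exhausted, which is why get_embedding above wraps its single dataset in a list. The dataset names and the ecg/I feature layout come from this project's custom TFDS builders, and the Helper import path is an assumption:

import tensorflow as tf
import tensorflow_datasets as tfds
from utils.helper import Helper   # import path assumed relative to src/

datasets = [
    tfds.load(name, split='train').batch(1024).prefetch(tf.data.AUTOTUNE)
    for name in ('zheng', 'ptb')   # custom builders registered by this project
]
gen = Helper.data_generator(datasets)   # default method='continue': endless round-robin
batch = next(gen)                       # yields batch['ecg']['I'], a batch of lead-I ECGs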
@@ -115,6 +121,17 @@ def load_embedding(path, dataset, split):
         df = pd.concat([df, y], axis=1)
         return df, latent_dim

+    @staticmethod
+    def load_multiple_datasets(datasets):
+        size = 0
+        data_list = []
+        for i, k in enumerate(datasets['name']):
+            temp = tfds.load(k, split=[datasets['split']], shuffle_files=True)
+            data = temp[0].shuffle(datasets['shuffle_size']).batch(datasets['batch_size']).prefetch(tf.data.AUTOTUNE)
+            size = size + len(data)
+            data_list.append(data)
+        return data_list, size
+
     @staticmethod
     def load_dataset(dataset):
         temp = tfds.load(dataset['name'],split=[dataset['split']], shuffle_files=True)
