🎨 align single process bayes and multiprocess bayes fct
- allows easier comparison of functionality
Henry committed Jul 8, 2024
1 parent 2df057b commit 2e87e12
Showing 2 changed files with 81 additions and 45 deletions.
72 changes: 36 additions & 36 deletions src/move/tasks/bayes_parallel.py
@@ -88,10 +88,7 @@ def _bayes_approach_worker(args):
models_path / f"baseline_recon_{task_config.model.num_latent}_{j}.pt"
)
if reconstruction_path.exists():
logger.debug(
f"Loading baseline reconstruction from {reconstruction_path}, "
"in the worker function"
)
logger.debug(f"Loading baseline reconstruction from {reconstruction_path}.")
baseline_recon = torch.load(reconstruction_path)

logger.debug(f"Loading model {model_path}, using load function")
@@ -173,36 +170,36 @@ def _bayes_approach_parallel(
num_continuous: int,
nan_mask: BoolArray,
feature_mask: BoolArray,
):
logger.debug("Inside the bayes_parallel function")
) -> tuple[Union[IntArray, FloatArray], ...]:
"""
Calculate Bayes factors for all perturbed features in parallel.
# First, I train or reload the models (number of refits),
# and save the baseline reconstruction.
# We train and get the reconstruction outside to make sure
# that we use the same model and use the same
# baseline reconstruction for all the worker functions
baseline_dataset = cast(MOVEDataset, baseline_dataloader.dataset)
First, we train or reload the models (one per refit) and save the baseline
reconstruction. We train and get the reconstruction outside the workers to make
sure that we use the same model and the same baseline reconstruction for all
the worker functions.
"""
logger.debug("Inside the bayes_parallel function")

assert task_config.model is not None
device = torch.device("cuda" if task_config.model.cuda else "cpu")
logger.debug("Model moved to device in bayes_approach_parallel")

# Train or reload models
logger.info("Training or reloading models")
# non-perturbed baseline dataset
baseline_dataset = cast(MOVEDataset, baseline_dataloader.dataset)

for j in range(
task_config.num_refits
): # We create as many models (refits) as indicated in the config file
for j in range(task_config.num_refits):
# We create as many models (refits) as indicated in the config file
# For each j (number of refits) we train a different model, but on the same data
# Initialize model
model: VAE = hydra.utils.instantiate(
task_config.model,
continuous_shapes=baseline_dataset.con_shapes,
categorical_shapes=baseline_dataset.cat_shapes,
)
if (
j == 0
): # First, we see if the models are already created (if we trained them
if j == 0:
# First, we see if the models are already created (if we trained them
# before). For each j, we check if model number j has already been created.
logger.debug(f"Model: {model}")

@@ -212,9 +209,8 @@ def _bayes_approach_parallel(
)
model_path = models_path / f"model_{task_config.model.num_latent}_{j}.pt"

if (
model_path.exists()
): # If the models were already created, we load them only if we need to get a
if model_path.exists():
# If the models were already created, we load them only if we need to get a
# baseline reconstruction. Otherwise, nothing needs to be done at this point
logger.debug(f"Model {model_path} already exists")
if not reconstruction_path.exists():
@@ -227,7 +223,8 @@ def _bayes_approach_parallel(
f"Baseline reconstruction for {reconstruction_path} already exists"
f", no need to load model {model_path} "
)
else: # If the models are not created yet, he have to train them, with the
else:
# If the models are not created yet, we have to train them, with the
# parameters we indicated in the config file
logger.debug(f"Training refit {j + 1}/{task_config.num_refits}")
model.to(device)
@@ -240,6 +237,7 @@ def _bayes_approach_parallel(
if task_config.save_refits:
# pickle_protocol=4 is necessary for very big models
torch.save(model.state_dict(), model_path, pickle_protocol=4)
model.eval()

# Calculate baseline reconstruction
# For each model j, we get a different reconstruction for the baseline.
@@ -249,9 +247,12 @@ def _bayes_approach_parallel(
# do it inside each process because the results might be different

if reconstruction_path.exists():
logger.debug(f"Baseline reconstruction for model {j} already created")
logger.debug(
f"Loading baseline reconstruction from {reconstruction_path}, "
"in the worker function"
)
# baseline_recon = torch.load(reconstruction_path)
else:
model.eval()
_, baseline_recon = model.reconstruct(baseline_dataloader)

# Save the baseline reconstruction for each model
@@ -260,13 +261,13 @@ def _bayes_approach_parallel(
logger.debug(f"Saved baseline reconstruction {j}")
del model

# Calculate Bayes factors
logger.info("Identifying significant features")

# Define more arguments that are needed for the worker functions
continuous_shapes = baseline_dataset.con_shapes
categorical_shapes = baseline_dataset.cat_shapes

"""
Perform parallelized bayes approach.
"""
logger.debug("Starting parallelization")

# Define arguments for each worker, and iterate over models and perturbed features
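The pool setup that consumes these worker arguments is not part of this hunk. A rough sketch, assuming a worker such as _bayes_approach_worker that receives one argument tuple per perturbed feature and returns (i, computed_bayes_k, mask_k) as collected in the loop below; the exact argument packing and pool size are assumptions:

import multiprocessing as mp

def fan_out(worker, args_list, processes: int = 4):
    """Run one worker call per perturbed feature in a process pool."""
    with mp.Pool(processes=processes) as pool:
        return pool.map(worker, args_list)

# Hypothetical usage, matching the result-collection loop shown below:
# for i, computed_bayes_k, mask_k in fan_out(_bayes_approach_worker, args_list):
#     bayes_k[i, :] = computed_bayes_k
#     bayes_mask[i, :] = mask_k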
@@ -304,25 +305,24 @@ def _bayes_approach_parallel(
# (log differences, i.e. Bayes factors)
bayes_k[i, :] = computed_bayes_k
bayes_mask[i, :] = mask_k

# mask already created in worker function
bayes_mask[bayes_mask != 0] = 1
bayes_mask = np.array(bayes_mask, dtype=bool)

# Calculate Bayes probabilities
bayes_abs = np.abs(bayes_k) # Dimensions are (num_perturbed, num_continuous)

bayes_p = np.exp(bayes_abs) / (1 + np.exp(bayes_abs)) # 2D: N x C

bayes_abs[bayes_mask] = np.min(
bayes_abs
) # Bring feature_i vs feature_i (self) associations to the minimum
# Get only the significant associations:
# This will flatten the array,so we get all bayes_abs for all perturbed features
# This will flatten the array, so we get all bayes_abs for all perturbed features
# vs all continuous features in one 1D array
# Then, we sort them and get the indexes in the flattened array. So, we get a
# list of sorted indexes in the flattened array
sort_ids = np.argsort(bayes_abs, axis=None)[::-1] # 1D: N x C
logger.debug(f"sort_ids are {sort_ids}")

# bayes_p is the array from which elements will be taken.
# sort_ids contains the indices that determine the order in which elements should
# be taken from bayes_p.
@@ -332,23 +332,23 @@ def _bayes_approach_parallel(
# elements using the provided indices.
# So, even though sort_ids is obtained from a flattened version of bayes_abs,
# np.take understands how to map these indices
# correctly to the original shape of bayes_p. We get a flattened array?
# correctly to the original shape of bayes_p.
prob = np.take(bayes_p, sort_ids) # 1D: N x C
logger.debug(f"Bayes proba range: [{prob[-1]:.3f} {prob[0]:.3f}]")

# Sort bayes_k in descending order, aligning with the sorted bayes_abs.
bayes_k = np.take(bayes_k, sort_ids) # 1D: N x C

# Calculate FDR
fdr = np.cumsum(1 - prob) / np.arange(1, prob.size + 1) # 1D ???
fdr = np.cumsum(1 - prob) / np.arange(1, prob.size + 1) # 1D
idx = np.argmin(np.abs(fdr - task_config.sig_threshold))
logger.debug(f"FDR range: [{fdr[0]:.3f} {fdr[-1]:.3f}]")
logger.debug(f"Index is {idx}")
# idx will contain the index of the element in fdr that is closest
# to task_config.sig_threshold.
# This line essentially finds the index where the False Discovery Rate (fdr) is
# closest to the significance threshold
# (task_config.sig_threshold).
logger.debug(f"Index is {idx}")
logger.debug(f"FDR range: [{fdr[0]:.3f} {fdr[-1]:.3f}]")

# Return elements only up to idx. They will be the significant findings
# sort_ids[:idx]: Indices of features sorted by significance.
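Putting the ranking and FDR steps from the two hunks above together, here is a toy walk-through with made-up Bayes factors for 2 perturbed and 3 continuous features; it only illustrates the transforms already shown and is not MOVE's code:

import numpy as np

bayes_k = np.array([[3.0, -0.2, 1.5],
                    [0.4, 2.5, -1.0]])                 # toy (num_perturbed, num_continuous)
bayes_abs = np.abs(bayes_k)
bayes_p = np.exp(bayes_abs) / (1 + np.exp(bayes_abs))  # logistic of |K|, values in (0.5, 1)

self_mask = np.zeros_like(bayes_k, dtype=bool)         # stand-in for bayes_mask
bayes_abs[self_mask] = np.min(bayes_abs)               # self-associations ranked last

sort_ids = np.argsort(bayes_abs, axis=None)[::-1]      # flat indices, strongest first
prob = np.take(bayes_p, sort_ids)                      # probabilities in that order

fdr = np.cumsum(1 - prob) / np.arange(1, prob.size + 1)
idx = int(np.argmin(np.abs(fdr - 0.05)))               # 0.05 plays the role of sig_threshold
significant = sort_ids[:idx]                           # flat ids of the reported associations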
54 changes: 45 additions & 9 deletions src/move/tasks/identify_associations.py
@@ -216,10 +216,12 @@ def _bayes_approach(
mean_diff = np.zeros((num_perturbed, num_samples, num_continuous))
normalizer = 1 / task_config.num_refits

# Last appended dataloader is the baseline
# non-perturbed baseline dataset
baseline_dataset = cast(MOVEDataset, baseline_dataloader.dataset)

for j in range(task_config.num_refits):
# We create as many models (refits) as indicated in the config file
# For each j (number of refits) we train a different model, but on the same data
# Initialize model
model: VAE = hydra.utils.instantiate(
task_config.model,
@@ -243,23 +245,32 @@ def _bayes_approach(
model=model,
train_dataloader=train_dataloader,
)
# Save the refits, to use them later
if task_config.save_refits:
torch.save(model.state_dict(), model_path)
# pickle_protocol=4 is necessary for very big models
torch.save(model.state_dict(), model_path, pickle_protocol=4)
model.eval()

# Calculate baseline reconstruction
# For each model j, we get a different reconstruction for the baseline.
# We haven't perturbed anything yet; we are just
# getting the reconstruction for the baseline. To make sure that we get
# the same reconstruction for each refit, we cannot
# do it inside each process because the results might be different
reconstruction_path = (
models_path / f"baseline_recon_{task_config.model.num_latent}_{j}.pt"
)
if reconstruction_path.exists():
logger.debug(
f"Loading baseline reconstruction from {reconstruction_path}, "
"in the worker function"
)
logger.debug(f"Loading baseline reconstruction from {reconstruction_path}.")
baseline_recon = torch.load(reconstruction_path)
else:
_, baseline_recon = model.reconstruct(baseline_dataloader)

# # Save the baseline reconstruction for each model
# logger.debug(f"Saving baseline reconstruction {j}")
# torch.save(baseline_recon, reconstruction_path, pickle_protocol=4)
# logger.debug(f"Saved baseline reconstruction {j}")

# Calculate perturb reconstruction => keep track of mean difference
for i in range(num_perturbed):
_, perturb_recon = model.reconstruct(dataloaders[i])
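The body of this loop is cut off by the diff. As an editorial sketch only, one plausible shape of the elided update is a running, refit-normalized difference between perturbed and baseline reconstructions; the exact masking MOVE applies here is not visible in this hunk:

import numpy as np

def accumulate_mean_diff(mean_diff, i, perturb_recon, baseline_recon, normalizer):
    """Add this refit's normalized difference for perturbed feature i.

    Summing diff * normalizer over num_refits refits equals the mean difference,
    since normalizer is 1 / num_refits.
    """
    diff = np.asarray(perturb_recon) - np.asarray(baseline_recon)  # (num_samples, num_continuous)
    mean_diff[i, :, :] += diff * normalizer
    return mean_diff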
@@ -285,28 +296,53 @@ def _bayes_approach(
bayes_mask = np.array(bayes_mask, dtype=bool)

# Calculate Bayes probabilities
bayes_abs = np.abs(bayes_k)
bayes_abs = np.abs(bayes_k) # Dimensions are (num_perturbed, num_continuous)

bayes_p = np.exp(bayes_abs) / (1 + np.exp(bayes_abs)) # 2D: N x C

bayes_abs[bayes_mask] = np.min(
bayes_abs
) # Bring feature_i vs feature_i (self) associations to the minimum
# Get only the significant associations:
# This will flatten the array, so we get all bayes_abs for all perturbed features
# vs all continuous features in one 1D array
# Then, we sort them and get the indexes in the flattened array. So, we get a
# list of sorted indexes in the flattened array
sort_ids = np.argsort(bayes_abs, axis=None)[::-1] # 1D: N x C
logger.debug(f"sort_ids are {sort_ids}")

# bayes_p is the array from which elements will be taken.
# sort_ids contains the indices that determine the order in which elements should
# be taken from bayes_p.
# This operation essentially rearranges the elements of bayes_p based on the
# sorting order specified by sort_ids
# np.take considers the input array as if it were flattened when extracting
# elements using the provided indices.
# So, even though sort_ids is obtained from a flattened version of bayes_abs,
# np.take understands how to map these indices
# correctly to the original shape of bayes_p.
prob = np.take(bayes_p, sort_ids) # 1D: N x C
logger.debug(f"Bayes proba range: [{prob[-1]:.3f} {prob[0]:.3f}]")

# Sort Bayes
# Sort bayes_k in descending order, aligning with the sorted bayes_abs.
bayes_k = np.take(bayes_k, sort_ids) # 1D: N x C

# Calculate FDR
fdr = np.cumsum(1 - prob) / np.arange(1, prob.size + 1) # 1D
idx = np.argmin(np.abs(fdr - task_config.sig_threshold))
# idx will contain the index of the element in fdr that is closest
# to task_config.sig_threshold.
# This line essentially finds the index where the False Discovery Rate (fdr) is
# closest to the significance threshold
# (task_config.sig_threshold).
logger.debug(f"Index is {idx}")
logger.debug(f"FDR range: [{fdr[0]:.3f} {fdr[-1]:.3f}]")

# Return elements only up to idx. They will be the significant findings
# sort_ids[:idx]: Indices of features sorted by significance.
# prob[:idx]: Probabilities of significant associations for selected features.
# fdr[:idx]: False Discovery Rate values for selected features.
# bayes_k[:idx]: Bayes Factors indicating the strength of evidence for selected
# associations.
return sort_ids[:idx], prob[:idx], fdr[:idx], bayes_k[:idx]
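The returned sort_ids are flat indices into the (num_perturbed, num_continuous) grid. A hypothetical post-processing helper (not part of MOVE) showing how they could be mapped back to feature pairs:

import numpy as np

def to_feature_pairs(sort_ids, num_perturbed: int, num_continuous: int):
    """Map flat association indices back to (perturbed_idx, continuous_idx) pairs."""
    rows, cols = np.unravel_index(sort_ids, (num_perturbed, num_continuous))
    return list(zip(rows.tolist(), cols.tolist()))

# e.g. with 4 continuous features, flat index 6 is perturbed feature 1, continuous feature 2
assert to_feature_pairs(np.array([6]), 3, 4) == [(1, 2)]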


