diff --git a/benchmark_utils/__init__.py b/benchmark_utils/__init__.py index bef2bb3210..81c8a2bd0b 100644 --- a/benchmark_utils/__init__.py +++ b/benchmark_utils/__init__.py @@ -37,7 +37,7 @@ map_hgnc_to_ensg, ) from .sanity_checks_utils import ( - run_purified_sanity_check, + # run_purified_sanity_check, run_sanity_check, ) from .tuning_utils import( diff --git a/benchmark_utils/dataset_utils.py b/benchmark_utils/dataset_utils.py index 3c01da76a3..599092bd79 100644 --- a/benchmark_utils/dataset_utils.py +++ b/benchmark_utils/dataset_utils.py @@ -14,22 +14,33 @@ def preprocess_scrna( adata: ad.AnnData, keep_genes: int = 2000, batch_key: Optional[str] = None ): - """Preprocess single-cell RNA data for deconvolution benchmarking.""" + """Preprocess single-cell RNA data for deconvolution benchmarking. + + * in adata.X, the normalized log1p counts are saved + * in adata.layers["counts"], raw counts are saved + * in adata.layers["relative_counts"], the relative counts are saved + => The highly variable genes can be found in adata.var["highly_variable"] + + """ sc.pp.filter_genes(adata, min_counts=3) adata.layers["counts"] = adata.X.copy() # preserve counts, used for training sc.pp.normalize_total(adata, target_sum=1e4) - adata.layers["relative_counts"] = adata.X.copy() # preserve counts, used for + adata.layers["relative_counts"] = adata.X.copy() # preserve counts, used for baselines sc.pp.log1p(adata) adata.raw = adata # freeze the state in `.raw` sc.pp.highly_variable_genes( adata, n_top_genes=keep_genes, - subset=True, layer="counts", flavor="seurat_v3", batch_key=batch_key, + subset=False, + inplace=True ) #TODO: add the filtering / QC steps that they perform in Servier + # concat the result df to adata.var + + return adata def split_dataset( diff --git a/benchmark_utils/deconv_utils.py b/benchmark_utils/deconv_utils.py index c5f5fa9f08..199998ea11 100644 --- a/benchmark_utils/deconv_utils.py +++ b/benchmark_utils/deconv_utils.py @@ -47,6 +47,7 @@ def perform_latent_deconv(adata_pseudobulk: ad.AnnData, scvi.model.MixUpVI, scvi.model.CondSCVI]], all_adata_samples, + filtered_genes, use_mixupvi: bool = True, use_nnls: bool = True, use_softmax: bool = False) -> pd.DataFrame: @@ -74,7 +75,7 @@ def perform_latent_deconv(adata_pseudobulk: ad.AnnData, if use_mixupvi: latent_pseudobulks=[] for i in range(len(all_adata_samples)): - latent_pseudobulks.append(model.get_latent_representation(all_adata_samples[i], get_pseudobulk=True)) + latent_pseudobulks.append(model.get_latent_representation(all_adata_samples[i,filtered_genes], get_pseudobulk=True)) latent_pseudobulk = np.concatenate(latent_pseudobulks, axis=0) else: adata_pseudobulk = ad.AnnData(X=adata_pseudobulk.layers["counts"], @@ -82,7 +83,7 @@ def perform_latent_deconv(adata_pseudobulk: ad.AnnData, var=adata_pseudobulk.var) adata_pseudobulk.layers["counts"] = adata_pseudobulk.X.copy() - latent_pseudobulk = model.get_latent_representation(adata_pseudobulk) + latent_pseudobulk = model.get_latent_representation(adata_pseudobulk, get_pseudobulk=False) if use_nnls: deconv = LinearRegression(positive=True).fit(adata_latent_signature.X.T, diff --git a/benchmark_utils/latent_signature_utils.py b/benchmark_utils/latent_signature_utils.py index 848e9e0a6f..08b813aec3 100644 --- a/benchmark_utils/latent_signature_utils.py +++ b/benchmark_utils/latent_signature_utils.py @@ -8,6 +8,7 @@ import torch from .dataset_utils import create_anndata_pseudobulk +from constants import SIGNATURE_TYPE def create_latent_signature( @@ -16,7 +17,6 @@ def 
create_latent_signature( repeats: int = 1, average_all_cells: bool = True, sc_per_pseudobulk: int = 3000, - signature_type: str = "pre-encoded", cell_type_column: str = "cell_types_grouped", count_key: Optional[str] = "counts", representation_key: Optional[str] = "X_scvi", @@ -101,34 +101,34 @@ def create_latent_signature( ) adata_sampled = adata[sampled_cells] - if signature_type == "pre-encoded": - assert ( - model is not None, - "If representing a purified pseudo bulk (aggregate before embedding", - "), must give a model", - ) - assert ( - count_key is not None - ), "Must give a count key if aggregating before embedding." - - if use_mixupvi: - result = model.get_latent_representation( - adata_sampled, get_pseudobulk=True - ).reshape(-1) - else: + assert ( + model is not None, + "If representing a purified pseudo bulk (aggregate before embedding", + "), must give a model", + ) + assert ( + count_key is not None + ), "Must give a count key if aggregating before embedding." + + if use_mixupvi: + # TODO: in this case, n_cells sampled will be equal to self.n_cells_per_pseudobulk by mixupvae + # so change that to being equal to either all cells (if average_all_cells) or sc_per_pseudobulk + result = model.get_latent_representation( + adata_sampled, get_pseudobulk=True + )[0] # take first pseudobulk + else: + if SIGNATURE_TYPE == "pre_encoded": pseudobulk = ( adata_sampled.layers[count_key].mean(axis=0).reshape(1, -1) - ) # .astype(int).astype(numpy.float32) - adata_pseudobulk = create_anndata_pseudobulk( - adata_sampled, pseudobulk ) - result = model.get_latent_representation(adata_pseudobulk).reshape( - -1 + adata_sampled = create_anndata_pseudobulk( + adata_sampled, pseudobulk ) - else: - raise ValueError( - "Only pre-encoded signatures are supported for now." - ) + result = model.get_latent_representation(adata_sampled) + if SIGNATURE_TYPE == "pre_encoded": + result = result.reshape(-1) + elif SIGNATURE_TYPE == "post_inference": + result = result.mean(axis=0) repeat_list.append(repeat) representation_list.append(result) cell_type_list.append(cell_type) diff --git a/benchmark_utils/plotting_utils.py b/benchmark_utils/plotting_utils.py index caecacfaff..340736fde0 100644 --- a/benchmark_utils/plotting_utils.py +++ b/benchmark_utils/plotting_utils.py @@ -1,9 +1,12 @@ +import os import matplotlib.pyplot as plt import seaborn as sns import numpy as np import pandas as pd from typing import Dict +from datetime import datetime +from loguru import logger def plot_purified_deconv_results(deconv_results, only_fit_one_baseline, more_details=False, save=False, filename="test"): """Plot the deconv results from sanity check 1""" @@ -106,7 +109,14 @@ def plot_deconv_lineplot(results: Dict[int, pd.DataFrame], plt.show() if save: - plt.savefig(f"/home/owkin/project/plots/{filename}.png", dpi=300) + path = f"/home/owkin/project/plots/{filename}.png" + if os.path.isfile(path): + new_path = f"/home/owkin/project/plots/{filename}_{datetime.now().strftime('%d_%m_%Y_%H_%M_%S')}.png" + logger.warning(f"{path} already exists. 
Saving file on this path instead: {new_path}") + path = new_path + plt.savefig(path, dpi=300) + logger.info(f"Plot saved to the following path: {path}") + def plot_metrics(model_history, train: bool = True, n_epochs: int = 100): """Plot the train or val metrics from training.""" @@ -241,6 +251,12 @@ def compare_tuning_results( custom_palette = sns.color_palette("husl", n_colors=len(all_results[variable_tuned].unique())) all_results["epoch"] = all_results.index + if (n_nan := all_results[variable_to_plot].isna().sum()) > 0: + print( + f"There are {n_nan} missing values in the variable to plot ({variable_to_plot})." + "Filling them with the next row values." + ) + all_results[variable_to_plot] = all_results[variable_to_plot].fillna(method='bfill') sns.set_theme(style="darkgrid") sns.lineplot(x="epoch", y=variable_to_plot, hue=variable_tuned, ci="sd", data=all_results, err_style="bars", palette=custom_palette) plt.show() diff --git a/benchmark_utils/sanity_checks_utils.py b/benchmark_utils/sanity_checks_utils.py index 1682823aeb..434bc672d9 100644 --- a/benchmark_utils/sanity_checks_utils.py +++ b/benchmark_utils/sanity_checks_utils.py @@ -32,129 +32,6 @@ def melt_df(deconv_results): ].copy() return deconv_results_melted_methods_temp -def run_purified_sanity_check( - adata_train: ad.AnnData, - adata_pseudobulk_test_counts: ad.AnnData, - adata_pseudobulk_test_rc: ad.AnnData, - signature: pd.DataFrame, - intersection: List[str], - generative_models : Dict[str, Union[scvi.model.SCVI, - scvi.model.CondSCVI, - scvi.model.DestVI, - scvi.model.MixUpVI]], - baselines: List[str], -): - """Run sanity check 1 on purified cell types. - - Sanity check 1 is an "easy" deconvolution task where the pseudobulk test dataset - is composed of purified cell types. Thus the groundtruth proportion is 1 for each - sample in the dataset. - - If the `generative_models` dictionnary is empty, only the baselines will be run. - - Parameters - ---------- - adata_train: ad.AnnData - scRNAseq training dataset. - adata_pseudobulk_test_counts: ad.AnnData - pseudobulk RNA seq test dataset (counts). - adata_pseudobulk_test_rc: ad.AnnData - pseudobulk RNA seq test dataset (relative counts). - signature: pd.DataFrame - Signature matrix. - intersection: List[str] - List of genes in common between the signature and the test dataset. - generative_models: Dict[str, scvi.model] - Dictionnary of generative models. - baselines: List[str] - List of baseline methods to run. - - Returns - pd.DataFrame - Melted dataframe of the deconvolution results. - """ - logger.info("Running sanity check...") - - # 1. 
Baselines - deconv_results_melted_methods = pd.DataFrame(columns=["Cell type predicted", "Cell type", "Estimated Fraction", "Method"]) - ## NNLS - if "nnls" in baselines: - deconv_results = perform_nnls(signature, adata_pseudobulk_test_rc[:, intersection]) - deconv_results_melted_methods_tmp = melt_df(deconv_results) - deconv_results_melted_methods_tmp["Method"] = "nnls" - deconv_results_melted_methods = pd.concat( - [deconv_results_melted_methods, deconv_results_melted_methods_tmp] - ) - - # Pseudobulk Dataframe for TAPE and Scaden - pseudobulk_test_df = pd.DataFrame( - adata_pseudobulk_test_rc[:, intersection].X, - index=adata_pseudobulk_test_rc.obs_names, - columns=intersection, - ) - ## TAPE - if "TAPE" in baselines: - _, deconv_results = \ - Deconvolution(signature.T, pseudobulk_test_df, - sep='\t', scaler='mms', - datatype='counts', genelenfile=None, - mode='overall', adaptive=True, variance_threshold=0.98, - save_model_name=None, - batch_size=128, epochs=128, seed=1) - deconv_results_melted_methods_tmp = melt_df(deconv_results) - deconv_results_melted_methods_tmp["Method"] = "TAPE" - deconv_results_melted_methods = pd.concat( - [deconv_results_melted_methods, deconv_results_melted_methods_tmp] - ) - ## Scaden - if "Scaden" in baselines: - deconv_results = ScadenDeconvolution(signature.T, - pseudobulk_test_df, - sep='\t', - batch_size=128, epochs=128) - deconv_results_melted_methods_tmp = melt_df(deconv_results) - deconv_results_melted_methods_tmp["Method"] = "Scaden" - deconv_results_melted_methods = pd.concat( - [deconv_results_melted_methods, deconv_results_melted_methods_tmp] - ) - - if generative_models == {}: - return deconv_results_melted_methods - - ### 2. Generative models - for model in generative_models.keys(): - if model == "DestVI": - continue - # DestVI is not used for Sanity check 1 - not enough - # samples to fit the stLVM. - # deconv_results = generative_models[model].get_proportions(adata_pseudobulk_test) - # deconv_results = deconv_results.drop(["noise_term"], - # axis=1, - # inplace=True) - # deconv_results_melted_methods_tmp = melt_df(deconv_results) - # deconv_results_melted_methods_tmp["Method"] = model - # deconv_results_melted_methods = pd.concat( - # [deconv_results_melted_methods, deconv_results_melted_methods_tmp] - # ) - else: - adata_latent_signature = create_latent_signature( - adata=adata_train, - model=generative_models[model], - average_all_cells = True, - sc_per_pseudobulk=3000, - ) - deconv_results = perform_latent_deconv( - adata_pseudobulk=adata_pseudobulk_test_counts, - adata_latent_signature=adata_latent_signature, - model=generative_models[model], - ) - deconv_results_melted_methods_tmp = melt_df(deconv_results) - deconv_results_melted_methods_tmp["Method"] = model - deconv_results_melted_methods = pd.concat( - [deconv_results_melted_methods, deconv_results_melted_methods_tmp] - ) - return deconv_results_melted_methods - def run_sanity_check( adata_train: ad.AnnData, @@ -163,7 +40,6 @@ def run_sanity_check( all_adata_samples_test: List[ad.AnnData], df_proportions_test: pd.DataFrame, signature: pd.DataFrame, - intersection: List[str], generative_models : Dict[str, Union[scvi.model.SCVI, scvi.model.CondSCVI, scvi.model.DestVI, @@ -185,8 +61,6 @@ def run_sanity_check( pseudobulk RNA seq test dataset (relative counts). signature: pd.DataFrame Signature matrix. - intersection: List[str] - List of genes in common between the signature and the test dataset. generative_models: Dict[str, scvi.model] Dictionnary of generative models. 
     baselines: List[str]
@@ -211,12 +85,16 @@ def run_sanity_check(
     # 1. Linear regression (NNLS)
     if "nnls" in baselines:
        deconv_results = perform_nnls(signature,
-                                      adata_pseudobulk_test_rc[:, intersection])
+                                      adata_pseudobulk_test_rc[:, signature.index])
        correlations = compute_correlations(deconv_results, df_proportions_test)
        group_correlations = compute_group_correlations(deconv_results, df_proportions_test)
        df_test_correlations.loc[:, "nnls"] = correlations.values
        df_test_group_correlations.loc[:, "nnls"] = group_correlations.values
 
+    # get all genes for which adata_train.var["highly_variable"] is True
+    filtered_genes = adata_train.var.index[adata_train.var["highly_variable"]].tolist()
+    intersection = list(set(signature.index).intersection(set(filtered_genes)))
+
     pseudobulk_test_df = pd.DataFrame(
         adata_pseudobulk_test_rc[:, intersection].X,
         index=adata_pseudobulk_test_rc.obs_names,
@@ -225,7 +103,7 @@ def run_sanity_check(
     # 2. TAPE
     if "TAPE" in baselines:
        _, deconv_results = \
-            Deconvolution(signature.T, pseudobulk_test_df,
+            Deconvolution(signature.loc[intersection].T, pseudobulk_test_df,
                           sep='\t', scaler='mms',
                           datatype='counts', genelenfile=None,
                           mode='overall', adaptive=True, variance_threshold=0.98,
@@ -237,7 +115,7 @@ def run_sanity_check(
        df_test_group_correlations.loc[:, "TAPE"] = group_correlations.values
     ## 3. Scaden
     if "Scaden" in baselines:
-        deconv_results = ScadenDeconvolution(signature.T,
+        deconv_results = ScadenDeconvolution(signature.loc[intersection].T,
                                              pseudobulk_test_df,
                                              sep='\t',
                                              batch_size=128, epochs=128)
@@ -265,16 +143,19 @@ def run_sanity_check(
             if model == "MixupVI":
                 use_mixupvi=True
             adata_latent_signature = create_latent_signature(
-                adata=adata_train,
+                adata=adata_train[:,filtered_genes],
                 model=generative_models[model],
-                use_mixupvi=use_mixupvi,
+                use_mixupvi=False, # should be equal to use_mixupvi, but if True,
+                # then it averages as many cells as self.n_cells_per_pseudobulk from mixupvae
+                # (and not the number we want in the benchmark)
                 average_all_cells = True,
                 sc_per_pseudobulk=10000,
             )
             deconv_results = perform_latent_deconv(
-                adata_pseudobulk=adata_pseudobulk_test_counts,
+                adata_pseudobulk=adata_pseudobulk_test_counts[:,filtered_genes],
                 all_adata_samples=all_adata_samples_test,
-                use_mixupvi=use_mixupvi,
+                filtered_genes=filtered_genes,
+                use_mixupvi=False, # see comment above
                 adata_latent_signature=adata_latent_signature,
                 model=generative_models[model],
             )
@@ -284,3 +165,128 @@ def run_sanity_check(
             df_test_group_correlations.loc[:, model] = group_correlations.values
 
     return df_test_correlations, df_test_group_correlations
+
+
+
+## NOT USED
+# def run_purified_sanity_check(
+#     adata_train: ad.AnnData,
+#     adata_pseudobulk_test_counts: ad.AnnData,
+#     adata_pseudobulk_test_rc: ad.AnnData,
+#     filtered_genes: list,
+#     signature: pd.DataFrame,
+#     generative_models : Dict[str, Union[scvi.model.SCVI,
+#                                         scvi.model.CondSCVI,
+#                                         scvi.model.DestVI,
+#                                         scvi.model.MixUpVI]],
+#     baselines: List[str],
+# ):
+#     """Run sanity check 1 on purified cell types.
+
+#     Sanity check 1 is an "easy" deconvolution task where the pseudobulk test dataset
+#     is composed of purified cell types. Thus the groundtruth proportion is 1 for each
+#     sample in the dataset.
+
+#     If the `generative_models` dictionary is empty, only the baselines will be run.
+
+#     Parameters
+#     ----------
+#     adata_train: ad.AnnData
+#         scRNAseq training dataset.
+#     adata_pseudobulk_test_counts: ad.AnnData
+#         pseudobulk RNA seq test dataset (counts).
+# adata_pseudobulk_test_rc: ad.AnnData +# pseudobulk RNA seq test dataset (relative counts). +# signature: pd.DataFrame +# Signature matrix. +# generative_models: Dict[str, scvi.model] +# Dictionnary of generative models. +# baselines: List[str] +# List of baseline methods to run. + +# Returns +# pd.DataFrame +# Melted dataframe of the deconvolution results. +# """ +# logger.info("Running sanity check...") + +# # 1. Baselines +# deconv_results_melted_methods = pd.DataFrame(columns=["Cell type predicted", "Cell type", "Estimated Fraction", "Method"]) +# ## NNLS +# if "nnls" in baselines: +# deconv_results = perform_nnls(signature, adata_pseudobulk_test_rc[:, signature.index]) +# deconv_results_melted_methods_tmp = melt_df(deconv_results) +# deconv_results_melted_methods_tmp["Method"] = "nnls" +# deconv_results_melted_methods = pd.concat( +# [deconv_results_melted_methods, deconv_results_melted_methods_tmp] +# ) + +# # Pseudobulk Dataframe for TAPE and Scaden +# intersection = set(signature.index).intersection(set(filtered_genes)) +# pseudobulk_test_df = pd.DataFrame( +# adata_pseudobulk_test_rc[:, intersection].X, +# index=adata_pseudobulk_test_rc.obs_names, +# columns=intersection, +# ) +# ## TAPE +# if "TAPE" in baselines: +# _, deconv_results = \ +# Deconvolution(signature.T, pseudobulk_test_df, +# sep='\t', scaler='mms', +# datatype='counts', genelenfile=None, +# mode='overall', adaptive=True, variance_threshold=0.98, +# save_model_name=None, +# batch_size=128, epochs=128, seed=1) +# deconv_results_melted_methods_tmp = melt_df(deconv_results) +# deconv_results_melted_methods_tmp["Method"] = "TAPE" +# deconv_results_melted_methods = pd.concat( +# [deconv_results_melted_methods, deconv_results_melted_methods_tmp] +# ) +# ## Scaden +# if "Scaden" in baselines: +# deconv_results = ScadenDeconvolution(signature.T, +# pseudobulk_test_df, +# sep='\t', +# batch_size=128, epochs=128) +# deconv_results_melted_methods_tmp = melt_df(deconv_results) +# deconv_results_melted_methods_tmp["Method"] = "Scaden" +# deconv_results_melted_methods = pd.concat( +# [deconv_results_melted_methods, deconv_results_melted_methods_tmp] +# ) + +# if generative_models == {}: +# return deconv_results_melted_methods + +# ### 2. Generative models +# for model in generative_models.keys(): +# if model == "DestVI": +# continue +# # DestVI is not used for Sanity check 1 - not enough +# # samples to fit the stLVM. 
+# # deconv_results = generative_models[model].get_proportions(adata_pseudobulk_test) +# # deconv_results = deconv_results.drop(["noise_term"], +# # axis=1, +# # inplace=True) +# # deconv_results_melted_methods_tmp = melt_df(deconv_results) +# # deconv_results_melted_methods_tmp["Method"] = model +# # deconv_results_melted_methods = pd.concat( +# # [deconv_results_melted_methods, deconv_results_melted_methods_tmp] +# # ) +# else: +# adata_latent_signature = create_latent_signature( +# adata=adata_train[:,filtered_genes], +# model=generative_models[model], +# average_all_cells = True, +# sc_per_pseudobulk=3000, +# ) +# deconv_results = perform_latent_deconv( +# adata_pseudobulk=adata_pseudobulk_test_counts[:,filtered_genes], +# adata_latent_signature=adata_latent_signature, +# model=generative_models[model], +# ) +# deconv_results_melted_methods_tmp = melt_df(deconv_results) +# deconv_results_melted_methods_tmp["Method"] = model +# deconv_results_melted_methods = pd.concat( +# [deconv_results_melted_methods, deconv_results_melted_methods_tmp] +# ) +# return deconv_results_melted_methods diff --git a/benchmark_utils/signature_utils.py b/benchmark_utils/signature_utils.py index 6dad54e394..2a1039d7b3 100644 --- a/benchmark_utils/signature_utils.py +++ b/benchmark_utils/signature_utils.py @@ -6,29 +6,33 @@ def create_signature( - adata: ad.AnnData, signature_type: str = "crosstissue_general", ): """Create the signature matrix from the single cell dataset.""" if signature_type == "laughney": - signature = pd.read_csv( - "/home/owkin/project/laughney_signature.csv", index_col=0 - ).drop(["Endothelial", "Malignant", "Stroma", "Epithelial"], axis=1) - # map the HGNC notation to ENSG if the signature matrix uses HGNC notation - mg = mygene.MyGeneInfo() - genes = mg.querymany( - signature.index, - scopes="symbol", - fields=["ensembl"], - species="human", - verbose=False, - as_dataframe=True, + raise NotImplementedError( + "Laughney signature not available now. To solve, upload it directly with " + "ENSG names." 
) - ensg_names = map_hgnc_to_ensg(genes, adata) - signature.index = ensg_names + # signature = pd.read_csv( + # "/home/owkin/project/laughney_signature.csv", index_col=0 + # ).drop(["Endothelial", "Malignant", "Stroma", "Epithelial"], axis=1) + # # map the HGNC notation to ENSG if the signature matrix uses HGNC notation + # mg = mygene.MyGeneInfo() + # genes = mg.querymany( + # signature.index, + # scopes="symbol", + # fields=["ensembl"], + # species="human", + # verbose=False, + # as_dataframe=True, + # ) + # ensg_names = map_hgnc_to_ensg(genes, adata) + # signature.index = ensg_names elif signature_type == "CTI_1st_level_granularity": signature = read_txt_r_signature( "/home/owkin/project/Almudena/Output/Crosstiss_Immune_norm/CTI.txt" + # "/home/owkin/project/Almudena/Output/Crosstiss_Immune/CTI.txt" ) # it is the normalised one (using adata.X and not adata.raw.X) elif signature_type == "CTI_2nd_level_granularity": signature = read_txt_r_signature( @@ -46,10 +50,7 @@ def create_signature( signature = read_txt_r_signature( "/home/owkin/project/Simon/signature_FACS_1st_level_granularity/FACS_1st_level_granularity_ensg.txt" ) - # intersection between all genes and marker genes - intersection = list(set(adata.var_names).intersection(signature.index)) - signature = signature.loc[intersection] - return signature, intersection + return signature def read_txt_r_signature(path): diff --git a/benchmark_utils/training_utils.py b/benchmark_utils/training_utils.py index 04c6c8c787..efc477e587 100644 --- a/benchmark_utils/training_utils.py +++ b/benchmark_utils/training_utils.py @@ -15,6 +15,7 @@ MAX_EPOCHS, BATCH_SIZE, LATENT_SIZE, + N_HIDDEN, N_PSEUDOBULKS, N_CELLS_PER_PSEUDOBULK, TRAIN_SIZE, @@ -29,8 +30,6 @@ MIXUP_PENALTY, DISPERSION, GENE_LIKELIHOOD, - MIXUP_PENATLY_AGGREGATION, - AVERAGE_VARIABLES_MIXUP_PENALTY, SEED, ) @@ -60,11 +59,12 @@ def tune_mixupvi(adata: ad.AnnData, search_space=search_space, num_samples=num_samples, # will randomly num_samples samples (with replacement) among the HP cominations specified max_epochs=MAX_EPOCHS, - resources={"cpu": 10, "gpu": 0.5}, + resources={"cpu": 6, "gpu": 1}, ) all_results, best_hp, tuning_path, search_path = format_and_save_tuning_results( tuning_results, variables=TUNED_VARIABLES, training_dataset=training_dataset, + cat_cov=CAT_COV, cont_cov=CONT_COV, ) return all_results, best_hp, tuning_path, search_path @@ -95,6 +95,7 @@ def fit_mixupvi(adata: ad.AnnData, n_pseudobulks=N_PSEUDOBULKS, n_cells_per_pseudobulk=N_CELLS_PER_PSEUDOBULK, n_latent=LATENT_SIZE, + n_hidden=N_HIDDEN, use_batch_norm=USE_BATCH_NORM, signature_type=SIGNATURE_TYPE, loss_computation=LOSS_COMPUTATION, @@ -103,8 +104,6 @@ def fit_mixupvi(adata: ad.AnnData, mixup_penalty=MIXUP_PENALTY, dispersion=DISPERSION, gene_likelihood=GENE_LIKELIHOOD, - mixup_penalty_aggregation=MIXUP_PENATLY_AGGREGATION, - average_variables_mixup_penalty=AVERAGE_VARIABLES_MIXUP_PENALTY, ) mixupvi_model.view_anndata_setup() mixupvi_model.train( diff --git a/benchmark_utils/tuning_utils.py b/benchmark_utils/tuning_utils.py index b53231e791..c1e77e646a 100644 --- a/benchmark_utils/tuning_utils.py +++ b/benchmark_utils/tuning_utils.py @@ -1,11 +1,17 @@ +"""Tuning utils file.""" + import json from collections import defaultdict import numpy as np import pandas as pd import os import pickle +from tuning_configs import TUNED_VARIABLES, SEARCH_SPACE, METRIC, ADDITIONAL_METRICS +from constants import TRAINING_DATASET -def format_and_save_tuning_results(tuning_results, variables: str, training_dataset : str): +def 
format_and_save_tuning_results( + tuning_results, variables: str, training_dataset : str, cat_cov : list, cont_cov : list, +): """Format the tuning results and save them in the project directory.""" # format the results of all experiments keys = list(tuning_results.results[0].metrics.keys()) @@ -61,6 +67,8 @@ def format_and_save_tuning_results(tuning_results, variables: str, training_data all_results.to_csv(tuning_path) search_space = tuning_results.search_space + search_space["cat_cov"] = cat_cov + search_space["cont_cov"] = cont_cov search_space["best_hp"] = best_hp with open(search_path, "wb") as ff: pickle.dump(search_space, ff) @@ -77,4 +85,62 @@ def read_search_space(search_path): search_space = pickle.load(ff) return search_space - \ No newline at end of file + +def format_and_save_tuning_results_backup(ray_directory: str = "tune_mixupvi_2024-04-08-08:55:24"): + """This function essentially does the same as format_and_save_tuning_results. + + But this one should be used in a handcrafted manner (by providing the ray directory + saved locally) when for some reason, tuning results were successfully saved locally + by ray, but not formatted and saved in the shared /project folder. + + Five global variables are used here and should be specified accordingly in the + tuning and constants config files : TUNED_VARIABLES, SEARCH_SPACE, TRAINING_DATASET, + METRIC, ADDITIONAL_METRICS + """ + directory = f"/home/owkin/deepdeconv-fork/ray/{ray_directory}/" + all_metrics = [METRIC] + ADDITIONAL_METRICS # all metric columns we want to retrieve + + all_results = [] + for path in os.listdir(directory): # loop through every result of hyperparameters tried + if path.startswith("_trainable"): + path = directory + path + results = defaultdict(list) + with open(path+"/result.json", "r") as ff: + for line in ff: + # loop through every epoch of the training + data = json.loads(line.strip()) + for key in all_metrics: + if key in data: + results[key].append(data[key]) + else: + results[key].append(np.nan) + results = pd.DataFrame(results) + + hyperparameters = path.split("/")[-1] + for i, variable in enumerate(sorted(TUNED_VARIABLES)): + hyperparameters=hyperparameters.split(f"{variable}=")[1] + if i < len(TUNED_VARIABLES)-1: + value = hyperparameters.split(",")[0] + else: + value = hyperparameters.split("-")[0][:-5] + results[variable] = value + + all_results.append(results) + + all_results = pd.concat(all_results) + # save results and search space + save_dir = f"/home/owkin/project/mixupvi_tuning/{'-'.join(TUNED_VARIABLES)}/" + new_path = save_dir + f"{TRAINING_DATASET}_dataset_{ray_directory}" + if not os.path.exists(save_dir): + # create a directory for the variable tuned + os.makedirs(save_dir) + if not os.path.exists(new_path): + # create a directory for the specific grid search performed + os.makedirs(new_path) + tuning_path = f"{new_path}/tuning_results.csv" + search_path = f"{new_path}/search_space.pkl" + all_results.to_csv(tuning_path) + + search_space = SEARCH_SPACE + with open(search_path, "wb") as ff: + pickle.dump(search_space, ff) diff --git a/constants.py b/constants.py index 494202c240..f06271c025 100644 --- a/constants.py +++ b/constants.py @@ -1,14 +1,14 @@ """Constants and global variables to run the different deconv files.""" -## constants for run_mixupvi.py +## Constants for run_mixupvi.py TUNE_MIXUPVI = True -TRAINING_DATASET = "CTI_PROCESSED" # ["CTI", "TOY", "CTI_PROCESSED", "CTI_RAW"] +TRAINING_DATASET = "CTI" # ["CTI", "TOY", "CTI_PROCESSED", "CTI_RAW"] TRAINING_CELL_TYPE_GROUP = ( 
"2nd_level_granularity" # ["1st_level_granularity", "2nd_level_granularity", "3rd_level_granularity", "4th_level_granularity", "FACS_1st_level_granularity"] ) -## constants for run_pseudobulk_benchmark.py -SIGNATURE_CHOICE = "CTI_2nd_level_granularity" # ["laughney", "CTI_1st_level_granularity", "CTI_2nd_level_granularity", "CTI_3rd_level_granularity", "CTI_4th_level_granularity", "FACS_1st_level_granularity"] +## Constants for run_pseudobulk_benchmark.py +SIGNATURE_CHOICE = "CTI_1st_level_granularity" # ["laughney", "CTI_1st_level_granularity", "CTI_2nd_level_granularity", "CTI_3rd_level_granularity", "CTI_4th_level_granularity", "FACS_1st_level_granularity"] if SIGNATURE_CHOICE in ["laughney", "CTI_1st_level_granularity"]: BENCHMARK_CELL_TYPE_GROUP = "1st_level_granularity" elif SIGNATURE_CHOICE == "CTI_2nd_level_granularity": @@ -21,28 +21,32 @@ BENCHMARK_CELL_TYPE_GROUP = "FACS_1st_level_granularity" else: BENCHMARK_CELL_TYPE_GROUP = None # no signature was created -BENCHMARK_DATASET = "CTI" # ["CTI", "TOY", "CTI_PROCESSED", "CTI_RAW"] +BENCHMARK_DATASET = "CTI" # ["CTI", "TOY", "CTI_RAW"] +BATCH_KEY = "donor_id" N_SAMPLES = 500 # number of pseudbulk samples to create and assess for deconvolution +N_CELLS = [2000] # list of number of cells to try for the lineplot GENERATIVE_MODELS = ["MixupVI"] #, "DestVI"] # "scVI", "CondscVI", "DestVI" -# GENERATIVE_MODELS = [] # if only want baselines BASELINES = ["nnls"] # "nnls", "TAPE", "Scaden" -# BASELINES = ["nnls"] # if only want nnls +COMPUTE_SC_RESULTS_WHEN_FACS = True -## general mixupvi constants when training it or preprocessing data +## General constants to change depending on the task SAVE_MODEL = False -SEED = 0 -N_GENES = 3000 # number of input genes after preprocessing -# MixUpVI training hyperparameters +SEED = 3 +LATENT_SIZE = 10 MAX_EPOCHS = 100 + +## Other constants to tune and then fix +N_GENES = 2000 # number of input genes after preprocessing +# MixUpVI training hyperparameters BATCH_SIZE = 2048 TRAIN_SIZE = 0.7 # as opposed to validation CHECK_VAL_EVERY_N_EPOCH = None if TRAIN_SIZE < 1: CHECK_VAL_EVERY_N_EPOCH = 1 # MixUpVI model hyperparameters -N_PSEUDOBULKS = 1 -N_CELLS_PER_PSEUDOBULK = None # None (then will be batch size) or int (will cap at batch size) -LATENT_SIZE = 30 +N_PSEUDOBULKS = 100 +N_CELLS_PER_PSEUDOBULK = 512 # None (then will be batch size) or int (will cap at batch size) +N_HIDDEN = 512 CONT_COV = None # None or list of continuous covariates to include CAT_COV = None # None or ["donor_id", "assay"] ENCODE_COVARIATES = False # whether to encode cont/cat covars (they are always decoded) @@ -52,8 +56,6 @@ MIXUP_PENALTY = "l2" # ["l2", "kl"] DISPERSION = "gene" # ["gene", "gene_label"] GENE_LIKELIHOOD = "zinb" # ["zinb", "nb", "poisson"] -MIXUP_PENATLY_AGGREGATION = "mean" # ["mean", "sum", "max"] -AVERAGE_VARIABLES_MIXUP_PENALTY = False USE_BATCH_NORM = "none" # ["encoder", "decoder", "none", "both"] # different possibilities of cell groupings with the CTI dataset @@ -133,7 +135,7 @@ "MemB": [ "Memory B cells"], "Plasma": ["Plasma cells", "Plasmablasts"], "Mono": ["Classical monocytes", "Nonclassical monocytes"], - "Macro":["Alveolar macrophages","Erythrophagocytic macrophages", + "Macro":["Alveolar macrophages","Erythrophagocytic macrophages", "Intermediate macrophages", "Intestinal macrophages"], "Naive_CD8T": [ "Tnaive/CM_CD8"], "Mem_CD8T": ["Tem/emra_CD8", "Trm/em_CD8", "Trm_gut_CD8"], @@ -145,8 +147,8 @@ "DC": ["DC1", "DC2", "migDC"], "pDC": ["pDC"], "Mast": ["Mast cells"], - "To remove": ["ABCs", "GC_B 
(I)", "GC_B (II)","Cycling", "T/B doublets", - "Cycling T&NK", "MNP/B doublets", "MNP/T doublets", "ILC3", + "To remove": ["ABCs", "GC_B (I)", "GC_B (II)","Cycling", "T/B doublets", + "Cycling T&NK", "MNP/B doublets", "MNP/T doublets", "ILC3", "MAIT","T_CD4/CD8", "Erythroid", "Megakaryocytes", "Progenitor"], }, @@ -173,23 +175,24 @@ "migDC": [ "migDC"], "pDC": ["pDC"], "Mast": ["Mast cells"], - "To remove": ["ABCs", "GC_B (I)", "GC_B (II)","Cycling", "T/B doublets", - "Cycling T&NK", "MNP/B doublets", "MNP/T doublets", "ILC3", + "To remove": ["ABCs", "GC_B (I)", "GC_B (II)","Cycling", "T/B doublets", + "Cycling T&NK", "MNP/B doublets", "MNP/T doublets", "ILC3", "MAIT","T_CD4/CD8", "Erythroid", "Megakaryocytes", "Progenitor"], }, "FACS_1st_level_granularity": { "B": ["Pre-B", "Pro-B", "Naive B cells","Memory B cells","Plasma cells"], - "NK": ["NK_CD16+", "NK_CD56bright_CD16-"], - "T": [ "Tnaive/CM_CD8","Tem/emra_CD8", "Trm/em_CD8", "Trm_gut_CD8","Tfh", - "Tnaive/CM_CD4", "Tnaive/CM_CD4_activated", "Teffector/EM_CD4", - "Trm_Th1/Th17","Tregs","T_CD4/CD8","Tgd_CRTAM+", "Trm_Tgd","MAIT"], - "Mono": ["Classical monocytes", "Nonclassical monocytes"], + "NK": ["NK_CD16+", "NK_CD56bright_CD16-"], + "T": [ "Tnaive/CM_CD8","Tem/emra_CD8", "Trm/em_CD8", "Trm_gut_CD8","Tfh", + "Tnaive/CM_CD4", "Tnaive/CM_CD4_activated", "Teffector/EM_CD4", + "Trm_Th1/Th17","Tregs","T_CD4/CD8","Tgd_CRTAM+", "Trm_Tgd","MAIT"], + "Mono": ["Classical monocytes", "Nonclassical monocytes"], "DC": ["DC1", "DC2", "migDC", "pDC"], - "To remove":["Plasmablasts","ABCs", "GC_B (I)", "GC_B (II)","Cycling", + "To remove":["Plasmablasts","ABCs", "GC_B (I)", "GC_B (II)","Cycling", "T/B doublets", "Cycling T&NK", "MNP/B doublets", "MNP/T doublets", "ILC3", "Erythroid", "Megakaryocytes", "Progenitor", - "Alveolar macrophages","Erythrophagocytic macrophages", + "Alveolar macrophages","Erythrophagocytic macrophages", "Intermediate macrophages", "Intestinal macrophages","Mast cells"] - } } + +# %% diff --git a/run_mixupvi.py b/run_mixupvi.py index 54f1afa785..f0e693ea00 100644 --- a/run_mixupvi.py +++ b/run_mixupvi.py @@ -102,14 +102,15 @@ # %% Load model / results: Uncomment if not running previous cells # if TUNE_MIXUPVI: -# path = "/home/owkin/project/mixupvi_tuning/n_latent-seed/CTI_PROCESSED_dataset_tune_mixupvi_2024-02-21-11:25:28" +# path = "/home/owkin/project/mixupvi_tuning/n_latent-seed/CTI_dataset_tune_mixupvi_2024-06-07-18:30:37" # all_results = read_tuning_results(f"{path}/tuning_results.csv") # search_space = read_search_space(f"{path}/search_space.pkl") -# best_hp = search_space["best_hp"] -# model_history = all_results.copy() -# for variable in best_hp : -# # plots for the best hp found by tuning -# model_history = model_history.loc[model_history[variable] == best_hp[variable]] +# if "best_hp" in search_space: +# best_hp = search_space["best_hp"] +# model_history = all_results.copy() +# for variable in best_hp : +# # plots for the best hp found by tuning +# model_history = model_history.loc[model_history[variable] == best_hp[variable]] # else: # import torch # model = torch.load(f"{model_path}/model.pt") @@ -130,17 +131,19 @@ # %% Plots to compare HPs if TUNE_MIXUPVI: - n_epochs = len(model_history["train_loss_epoch"]) - hp_index_to_plot = None - # hp_index_to_plot = [1, 2, 3] # only these index (of the HPs tried) will be plotted, for clearer visualisation - - if len(best_hp) == 1 or (len(best_hp) == 2 and "seed" in best_hp): - tuned_variable = list(set(best_hp.keys()) - {"seed"})[0] + n_epochs = 
len(set(all_results["train_loss_epoch"].index)) + # hp_index_to_plot = None + hp_index_to_plot = [0,1] # only these index (of the HPs tried) will be plotted, for clearer visualisation + + tuned_hps = all_results.T.loc[["train" not in col and "validation" not in col for col in all_results.columns]].index + if len(tuned_hps) == 1 or (len(tuned_hps) == 2 and "seed" in tuned_hps): + variable_tuned = list(set(tuned_hps) - {"seed"})[0] + # variable_tuned = "seed" for variable_to_plot in all_results.columns: if "validation" in variable_to_plot: compare_tuning_results( - all_results, variable_to_plot=variable_to_plot, - variable_tuned=tuned_variable, n_epochs=n_epochs, + all_results.copy(), variable_to_plot=variable_to_plot, + variable_tuned=variable_tuned, n_epochs=n_epochs, hp_index_to_plot=hp_index_to_plot, ) else: diff --git a/run_pseudobulk_benchmark.py b/run_pseudobulk_benchmark.py index c9685295f6..f44f8d3a1f 100644 --- a/run_pseudobulk_benchmark.py +++ b/run_pseudobulk_benchmark.py @@ -1,8 +1,9 @@ """Pseudobulk benchmark.""" # %% import scanpy as sc -from loguru import logger +import os import warnings +from loguru import logger from constants import ( BENCHMARK_DATASET, @@ -13,21 +14,23 @@ N_SAMPLES, GENERATIVE_MODELS, BASELINES, + N_CELLS, + BATCH_KEY ) from benchmark_utils import ( preprocess_scrna, - create_purified_pseudobulk_dataset, - create_uniform_pseudobulk_dataset, + # create_purified_pseudobulk_dataset, + # create_uniform_pseudobulk_dataset, create_dirichlet_pseudobulk_dataset, fit_scvi, fit_destvi, fit_mixupvi, create_signature, add_cell_types_grouped, - run_purified_sanity_check, + # run_purified_sanity_check, run_sanity_check, - plot_purified_deconv_results, + # plot_purified_deconv_results, plot_deconv_results, plot_deconv_results_group, plot_deconv_lineplot, @@ -43,28 +46,30 @@ # adata = scvi.data.heart_cell_atlas_subsampled() # preprocess_scrna(adata, keep_genes=1200) elif BENCHMARK_DATASET == "CTI": - adata = sc.read("/home/owkin/project/cti/cti_adata.h5ad") - preprocess_scrna(adata, - keep_genes=N_GENES, - batch_key="donor_id") + cti_path = f"/home/owkin/data/cti_data/processed/cti_processed_{N_GENES}.h5ad" + if os.path.exists(cti_path): + adata = sc.read(f"/home/owkin/data/cti_data/processed/cti_processed_{N_GENES}.h5ad") + else: + adata = sc.read("/home/owkin/project/cti/cti_adata.h5ad") + adata = preprocess_scrna(adata, + keep_genes=N_GENES, + batch_key=BATCH_KEY) + adata.write(cti_path) + filtered_genes = adata.var.index[adata.var["highly_variable"]].tolist() elif BENCHMARK_DATASET == "CTI_RAW": warnings.warn("The raw data of this adata is on adata.raw.X, but the normalised " "adata.X will be used here") adata = sc.read("/home/owkin/data/cross-tissue/omics/raw/local.h5ad") - preprocess_scrna(adata, + adata = preprocess_scrna(adata, keep_genes=N_GENES, batch_key="donor_id", ) -elif BENCHMARK_DATASET == "CTI_PROCESSED": - # Load processed for speed-up (already filtered, normalised, etc.) 
- adata = sc.read(f"/home/owkin/data/cti_data/processed/cti_processed_{N_GENES}.h5ad") + #save adata + adata.write(f"/home/owkin/data/cti_data/processed/cti_processed_{N_GENES}.h5ad") # %% load signature logger.info(f"Loading signature matrix: {SIGNATURE_CHOICE} | {BENCHMARK_CELL_TYPE_GROUP}...") -signature, intersection = create_signature( - adata, - signature_type=SIGNATURE_CHOICE, -) +signature = create_signature(signature_type=SIGNATURE_CHOICE) # %% add cell types groups and split train/test adata, train_test_index = add_cell_types_grouped(adata, BENCHMARK_CELL_TYPE_GROUP) @@ -80,7 +85,7 @@ if "scVI" in GENERATIVE_MODELS: logger.info("Fit scVI ...") model_path = f"project/models/{BENCHMARK_DATASET}_scvi.pkl" - scvi_model = fit_scvi(adata_train, + scvi_model = fit_scvi(adata_train[:, filtered_genes].copy(), model_path, save_model=SAVE_MODEL) generative_models["scVI"] = scvi_model @@ -94,12 +99,12 @@ # ) # Dirichlet adata_pseudobulk_train_counts, adata_pseudobulk_train_rc, df_proportions_test = create_dirichlet_pseudobulk_dataset( - adata_train, prior_alphas = None, n_sample = N_SAMPLES, + adata_train.copy(), prior_alphas = None, n_sample = N_SAMPLES, ) model_path_1 = f"project/models/{BENCHMARK_DATASET}_condscvi.pkl" model_path_2 = f"project/models/{BENCHMARK_DATASET}_destvi.pkl" - condscvi_model , destvi_model= fit_destvi(adata_train, + condscvi_model , destvi_model= fit_destvi(adata_train[:, filtered_genes].copy(), adata_pseudobulk_train_counts, model_path_1, model_path_2, @@ -112,7 +117,7 @@ if "MixupVI" in GENERATIVE_MODELS: logger.info("Train mixupVI ...") model_path = f"project/models/{BENCHMARK_DATASET}_{BENCHMARK_CELL_TYPE_GROUP}_{N_GENES}_mixupvi.pkl" - mixupvi_model = fit_mixupvi(adata_train, + mixupvi_model = fit_mixupvi(adata_train[:, filtered_genes].copy(), model_path, cell_type_group="cell_types_grouped", save_model=SAVE_MODEL, @@ -121,14 +126,10 @@ # %% Sanity check 3 -#num_cells = [50, 100, 300, 500, 1000] - -num_cells = [2000] - results = {} results_group = {} -for n in num_cells: +for n in N_CELLS: logger.info(f"Pseudobulk simulation with {n} sampled cells ...") all_adata_samples_test, adata_pseudobulk_test_counts, adata_pseudobulk_test_rc, df_proportions_test = create_dirichlet_pseudobulk_dataset( adata_test, @@ -151,7 +152,6 @@ all_adata_samples_test=all_adata_samples_test, df_proportions_test=df_proportions_test, signature=signature, - intersection=intersection, generative_models=generative_models, baselines=BASELINES, ) @@ -163,15 +163,17 @@ if len(results) > 1: plot_deconv_lineplot(results, save=True, - filename=f"sim_pseudobulk_lineplot") + filename=f"lineplot_tuned_mixupvi_third_granularity_retry_normal") else: key = list(results.keys())[0] plot_deconv_results(results[key], save=True, - filename=f"sim_pseudobulk_{key}") + # filename=f"benchmark_{key}_cells_first_granularity") + filename="test_first_type") plot_deconv_results_group(results_group[key], save=True, - filename=f"sim_pseudobulk_{key}_per_celltype") + # filename=f"benchmark_{key}_cells_first_granularity_cell_type") + filename="test_first_type_cell_type") # %% (Optional) Sanity check 1. 
@@ -184,7 +186,6 @@ # adata_pseudobulk_test_counts=adata_pseudobulk_test_counts, # adata_pseudobulk_test_rc=adata_pseudobulk_test_rc, # signature=signature, -# intersection=intersection, # generative_models=generative_models, # baselines=BASELINES, # ) diff --git a/scvi/model/_mixupvi.py b/scvi/model/_mixupvi.py index 3a8e0c4ba6..25ba37a02f 100644 --- a/scvi/model/_mixupvi.py +++ b/scvi/model/_mixupvi.py @@ -52,7 +52,7 @@ def get_latent_representation( self, adata: Optional[AnnData] = None, indices: Optional[Sequence[int]] = None, - get_pseudobulk: bool = True, + get_pseudobulk: bool = False, give_mean: bool = True, mc_samples: int = 5000, batch_size: Optional[int] = None, diff --git a/scvi/module/_mixupvae.py b/scvi/module/_mixupvae.py index bbf14463ba..3794db19a4 100644 --- a/scvi/module/_mixupvae.py +++ b/scvi/module/_mixupvae.py @@ -18,7 +18,6 @@ from scvi.nn import Encoder from ._vae import VAE from ._utils import ( - run_incompatible_value_checks, get_mean_pearsonr_torch, compute_ground_truth_proportions, compute_signature, @@ -103,16 +102,6 @@ class MixUpVAE(VAE): Whether to concatenate covariates to expression in encoder mixup_penalty The loss to use to compare the average of encoded cell and the encoded pseudobulk. - mixup_penalty_aggregation - One of - - * ``'sum'`` - Sum the n_pseudobulk L2 losses - * ``'mean'`` - Average the n_pseudobulk L2 losses - * ``'max'`` - Take the max among the n_pseudobulk L2 losses - average_variables_mixup_penalty - Whether to average the mixup penalty across the different variables. When False, the values are just summed. - If the loss_computation is in the latent space, there are n_latent variables. - If inside the reconstructed space, there are n_input variables. deeply_inject_covariates Whether to concatenate covariates into output of hidden layers in encoder/decoder. This option only applies when `n_layers` > 1. The covariates are concatenated to the input of subsequent hidden layers. @@ -146,7 +135,7 @@ def __init__( n_input: int, n_batch: int = 0, n_labels: int = 0, - n_hidden: Tunable[int] = 128, + n_hidden: Tunable[int] = 512, n_latent: Tunable[int] = 10, n_layers: Tunable[int] = 1, seed: Tunable[int] = 0, @@ -177,8 +166,6 @@ def __init__( loss_computation: Tunable[str] = "latent_space", pseudo_bulk: Tunable[str] = "pre_encoded", mixup_penalty: Tunable[str] = "l2", - mixup_penalty_aggregation: Tunable[str] = "mean", - average_variables_mixup_penalty: Tunable[bool] = False, ): torch.manual_seed(seed) @@ -208,13 +195,12 @@ def __init__( extra_encoder_kwargs=extra_encoder_kwargs, extra_decoder_kwargs=extra_decoder_kwargs, ) - run_incompatible_value_checks( - pseudo_bulk=pseudo_bulk, - loss_computation=loss_computation, - use_batch_norm=use_batch_norm, - mixup_penalty=mixup_penalty, - gene_likelihood=gene_likelihood, - ) + if use_batch_norm != "none" and n_pseudobulks == 1: + raise ValueError( + "Batch normalization cannot be used when only one pseudobulk is " + "computed - it cannot be considered as a batch on which batch " + "normalization can be applied." 
+ ) self.n_pseudobulks = n_pseudobulks self.n_cells_per_pseudobulk = n_cells_per_pseudobulk @@ -222,8 +208,6 @@ def __init__( self.loss_computation = loss_computation self.pseudo_bulk = pseudo_bulk self.mixup_penalty = mixup_penalty - self.mixup_penalty_aggregation = mixup_penalty_aggregation - self.average_variables_mixup_penalty = average_variables_mixup_penalty self.z_signature = None self.logger_messages = set() @@ -366,7 +350,7 @@ def inference( categorical_pseudobulk_input = [] categorical_signature_input = [] j=0 - for n_cat in self.z_encoder.encoder.n_cat_list: + for n_cat in self.decoder.px_decoder.n_cat_list: if n_cat > 0 : # if n_cat == 0 then no batch index was given, so skip it one_hot_cat_covs = one_hot(cat_covs[j], n_cat) @@ -653,11 +637,18 @@ def loss( cosine_deconv_results = [] mse_deconv_results = [] mae_deconv_results = [] + z_signature = inference_outputs["z_signature"] if self.z_signature is None else self.z_signature for i, pseudobulk in enumerate(pseudobulk_z.detach().cpu().numpy()): predicted_proportions = nnls( - self.z_signature.detach().cpu().numpy().T, + z_signature.detach().cpu().numpy().T, pseudobulk, )[0] + if self.z_signature is None: + # resize predicted_proportions to the right number of labels + full_predicted_proportions = np.zeros(self.n_labels) + for j, cell_type in enumerate(tensors["labels"].unique().detach().cpu()): + full_predicted_proportions[int(cell_type)] = predicted_proportions[j] + predicted_proportions = full_predicted_proportions if np.any(predicted_proportions): # if not all zeros, sum the predictions to 1 predicted_proportions = predicted_proportions / predicted_proportions.sum() @@ -727,8 +718,6 @@ def get_mix_up_loss(self, inference_outputs, generative_outputs): mean_single_cells = generative_outputs["px"].rate[pseudobulk_indices, :].mean(axis=1) pseudobulk = generative_outputs["px_pseudobulk"].rate mixup_penalty = torch.sum((pseudobulk - mean_single_cells) ** 2, axis=1) - if self.average_variables_mixup_penalty: - mixup_penalty /= mean_single_cells.shape[1] elif self.mixup_penalty == "kl": # kl of mean(cells) compared to reference pseudobulk if self.loss_computation == "latent_space": @@ -765,10 +754,7 @@ def get_mix_up_loss(self, inference_outputs, generative_outputs): mixup_penalty = kl( averaged_cells_distrib, pseudobulk_reference_distrib ).sum(dim=-1) - if self.mixup_penalty_aggregation == "max": - mixup_penalty = mixup_penalty.max() - else : - mixup_penalty = torch.sum(mixup_penalty) - if self.mixup_penalty_aggregation == "mean": - mixup_penalty /= mean_single_cells.shape[0] + + mixup_penalty = torch.sum(mixup_penalty) / mean_single_cells.shape[0] + return mixup_penalty diff --git a/scvi/module/_utils.py b/scvi/module/_utils.py index 482f8626d8..a1fc210c80 100644 --- a/scvi/module/_utils.py +++ b/scvi/module/_utils.py @@ -140,51 +140,4 @@ def get_mean_pearsonr_torch(x, y): r_num = (xm*ym).sum(dim=1) r_den = torch.norm(xm, p=2, dim=1) * torch.norm(ym, p=2, dim=1) r_val = r_num / r_den - return torch.mean(r_val) - - -def run_incompatible_value_checks( - pseudo_bulk, loss_computation, use_batch_norm, mixup_penalty, gene_likelihood -): - """Check the values of the categorical variables to run MixUpVI are compatible. - The first 4 checks will only be relevant when pseudobulk will not be computed both - in encoder and decoder (right now, computed in both). Until then, use_batch_norm - should be None. 
- """ - if ( - pseudo_bulk == "pre_encoded" - and loss_computation == "latent_space" - and use_batch_norm in ["encoder", "both"] - ): - raise ValueError( - "MixUpVI cannot use batch normalization there, as the batch size of pseudobulk is 1." - ) - elif ( - pseudo_bulk == "pre_encoded" - and loss_computation == "reconstructed_space" - and use_batch_norm != "none" - ): - raise ValueError( - "MixUpVI cannot use batch normalization there, as the batch size of pseudobulk is 1." - ) - elif pseudo_bulk == "post_inference" and loss_computation == "latent_space": - raise ValueError( - "Pseudo bulk needs to be pre-encoded to compute the MixUp loss in the latent space." - ) - elif ( - pseudo_bulk == "post_inference" - and loss_computation == "reconstructed_space" - and use_batch_norm in ["decoder", "both"] - ): - raise ValueError( - "MixUpVI cannot use batch normalization there, as the batch size of pseudobulk is 1." - ) - if ( - mixup_penalty == "kl" - and loss_computation != "latent_space" - and gene_likelihood == "zinb" - ): - raise NotImplementedError( - "The KL divergence between ZINB distributions for the MixUp loss is not " - "implemented." - ) \ No newline at end of file + return torch.mean(r_val) \ No newline at end of file diff --git a/tuning_configs.py b/tuning_configs.py index 0ed630c7ba..3d625b040e 100644 --- a/tuning_configs.py +++ b/tuning_configs.py @@ -1,3 +1,5 @@ +"""Hyperparameter search configs.""" + from ray import tune from constants import ( @@ -16,8 +18,6 @@ LATENT_SIZE, N_PSEUDOBULKS, N_CELLS_PER_PSEUDOBULK, - MIXUP_PENATLY_AGGREGATION, - AVERAGE_VARIABLES_MIXUP_PENALTY, SEED, ) @@ -32,6 +32,7 @@ repeat_with_several_seeds = { "seed": tune.grid_search( [0, 3, 8, 12, 23] + # [0,1] ) } example_with_several_seeds = { @@ -40,13 +41,10 @@ } latent_space_search_space = { "n_latent": tune.grid_search( - list(range(len(GROUPS[TRAINING_CELL_TYPE_GROUP]) - 1, 550, 20)) # from n cell types to n marker genes - ) -} -latent_space_search_space_precise = { - "n_latent": tune.grid_search( - list(range(len(GROUPS[TRAINING_CELL_TYPE_GROUP]) - 1, 60, 5)) # from n cell types to 60 - ) + # list(range(len(GROUPS[TRAINING_CELL_TYPE_GROUP]) - 1, 550, 20)) # from n cell types to n marker genes + [70, 100, 120, 150] + ), + "seed": tune.grid_search([3, 8, 12, 23, 42]) } batch_size_search_space = { "batch_size": tune.grid_search( @@ -61,7 +59,8 @@ signature_type_search_space = { "signature_type": tune.grid_search( ["pre_encoded", "post_inference"] - ) + ), + "seed": tune.grid_search([3, 8, 12, 23, 42]) } loss_computation_search_space = { "loss_computation": tune.grid_search( @@ -69,20 +68,31 @@ ) } gene_likelihood_search_space = { - "gene_likelihood": tune.grid_search(["zinb", "nb", "poisson"]) + "gene_likelihood": tune.grid_search(["zinb", "nb", "poisson"]), + "seed": tune.grid_search([3, 8, 12, 23, 42]) } n_hidden_search_space = { - "n_hidden": tune.grid_search([128, 256, 512, 1024]) + "n_hidden": tune.grid_search([128, 256, 512, 1024]), + "seed": tune.grid_search([3, 8, 12, 23, 42]) } n_layers_search_space = { - "n_layers": tune.grid_search([1, 2, 3]) + "n_layers": tune.grid_search([1, 2, 3]), + "seed": tune.grid_search([3, 8, 12, 23, 42]) } n_pseudobulks_search_space = { - "n_pseudobulks": tune.grid_search([1, 5, 10, 30, 50, 100]), + "n_pseudobulks": tune.grid_search([1, 100]), + "seed": tune.grid_search([3, 8, 12]) + # "seed": tune.grid_search([3, 8, 12, 23, 42]) +} +n_cells_per_pseudobulk_search_space = { + "n_cells_per_pseudobulk": tune.grid_search([100, 256, 512, 1024, 2048]), "seed": 
tune.grid_search([3, 8, 12]) - # "seed": tune.grid_search([3, 8, 12, 23, 42]) } -SEARCH_SPACE = n_pseudobulks_search_space +use_batch_norm_search_space = { + "use_batch_norm": tune.grid_search(["none", "encoder", "decoder", "both"]), + "seed": tune.grid_search([3, 8, 12, 23, 42]) +} +SEARCH_SPACE = latent_space_search_space TUNED_VARIABLES = list(SEARCH_SPACE.keys()) NUM_SAMPLES = 1 # will only perform once the gridsearch (useful to change if mix of grid and random search for instance) @@ -102,8 +112,6 @@ "n_latent": LATENT_SIZE, "n_pseudobulks": N_PSEUDOBULKS, "n_cells_per_pseudobulk": N_CELLS_PER_PSEUDOBULK, - "mixup_penalty_aggregation": MIXUP_PENATLY_AGGREGATION, - "average_variables_mixup_penalty": AVERAGE_VARIABLES_MIXUP_PENALTY, "seed": SEED, } for key in list(model_fixed_hps): @@ -136,4 +144,4 @@ "pearson_coeff_deconv_train", "mse_deconv_train", "mae_deconv_train", -] +] \ No newline at end of file