Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Changes related to sensitivity analysis #15

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion benchmark_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
map_hgnc_to_ensg,
)
from .sanity_checks_utils import (
run_purified_sanity_check,
# run_purified_sanity_check,
run_sanity_check,
)
from .tuning_utils import(
Expand Down
17 changes: 14 additions & 3 deletions benchmark_utils/dataset_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,22 +14,33 @@
def preprocess_scrna(
    adata: ad.AnnData, keep_genes: int = 2000, batch_key: Optional[str] = None
):
    """Preprocess single-cell RNA data for deconvolution benchmarking.

    The object is modified in place and also returned. After the call:

    * in adata.X, the normalized log1p counts are saved
    * in adata.layers["counts"], raw counts are saved
      (assumes adata.X holds raw counts on entry -- TODO confirm with callers)
    * in adata.layers["relative_counts"], the relative counts are saved
      (output of ``sc.pp.normalize_total`` with ``target_sum=1e4``, taken
      before the log1p transform)
    * in adata.raw, the log-normalized state is frozen
    => The highly variable genes can be found in adata.var["highly_variable"]
       (genes are only flagged, not removed, because ``subset=False``).

    Parameters
    ----------
    adata
        Annotated data matrix to preprocess.
    keep_genes
        Number of top highly variable genes to flag (``n_top_genes``).
    batch_key
        Optional batch column forwarded to ``sc.pp.highly_variable_genes``.

    Returns
    -------
    The same ``adata`` object that was passed in.
    """
    sc.pp.filter_genes(adata, min_counts=3)
    # Snapshot taken before any normalization; used for training.
    adata.layers["counts"] = adata.X.copy()
    sc.pp.normalize_total(adata, target_sum=1e4)
    # Snapshot of the normalized, not yet log-transformed values; used for baselines.
    adata.layers["relative_counts"] = adata.X.copy()
    sc.pp.log1p(adata)
    adata.raw = adata  # freeze the state in `.raw`
    # `subset` is passed exactly once: the selection is run on the "counts"
    # layer and the result is written to adata.var instead of dropping genes.
    sc.pp.highly_variable_genes(
        adata,
        n_top_genes=keep_genes,
        layer="counts",
        flavor="seurat_v3",
        batch_key=batch_key,
        subset=False,
        inplace=True,
    )
    # TODO: add the filtering / QC steps that they perform in Servier
    # concat the result df to adata.var

    return adata


def split_dataset(
Expand Down
5 changes: 3 additions & 2 deletions benchmark_utils/deconv_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def perform_latent_deconv(adata_pseudobulk: ad.AnnData,
scvi.model.MixUpVI,
scvi.model.CondSCVI]],
all_adata_samples,
filtered_genes,
use_mixupvi: bool = True,
use_nnls: bool = True,
use_softmax: bool = False) -> pd.DataFrame:
Expand Down Expand Up @@ -74,15 +75,15 @@ def perform_latent_deconv(adata_pseudobulk: ad.AnnData,
if use_mixupvi:
latent_pseudobulks=[]
for i in range(len(all_adata_samples)):
latent_pseudobulks.append(model.get_latent_representation(all_adata_samples[i], get_pseudobulk=True))
latent_pseudobulks.append(model.get_latent_representation(all_adata_samples[i,filtered_genes], get_pseudobulk=True))
khalilouardini marked this conversation as resolved.
Show resolved Hide resolved
latent_pseudobulk = np.concatenate(latent_pseudobulks, axis=0)
else:
adata_pseudobulk = ad.AnnData(X=adata_pseudobulk.layers["counts"],
obs=adata_pseudobulk.obs,
var=adata_pseudobulk.var)
adata_pseudobulk.layers["counts"] = adata_pseudobulk.X.copy()

latent_pseudobulk = model.get_latent_representation(adata_pseudobulk)
latent_pseudobulk = model.get_latent_representation(adata_pseudobulk, get_pseudobulk=False)

if use_nnls:
deconv = LinearRegression(positive=True).fit(adata_latent_signature.X.T,
Expand Down
50 changes: 25 additions & 25 deletions benchmark_utils/latent_signature_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import torch

from .dataset_utils import create_anndata_pseudobulk
from constants import SIGNATURE_TYPE


def create_latent_signature(
Expand All @@ -16,7 +17,6 @@ def create_latent_signature(
repeats: int = 1,
average_all_cells: bool = True,
sc_per_pseudobulk: int = 3000,
signature_type: str = "pre-encoded",
cell_type_column: str = "cell_types_grouped",
count_key: Optional[str] = "counts",
representation_key: Optional[str] = "X_scvi",
Expand Down Expand Up @@ -101,34 +101,34 @@ def create_latent_signature(
)
adata_sampled = adata[sampled_cells]

if signature_type == "pre-encoded":
assert (
model is not None,
"If representing a purified pseudo bulk (aggregate before embedding",
"), must give a model",
)
assert (
count_key is not None
), "Must give a count key if aggregating before embedding."

if use_mixupvi:
result = model.get_latent_representation(
adata_sampled, get_pseudobulk=True
).reshape(-1)
else:
assert (
model is not None,
"If representing a purified pseudo bulk (aggregate before embedding",
"), must give a model",
)
assert (
count_key is not None
), "Must give a count key if aggregating before embedding."

if use_mixupvi:
# TODO: in this case, n_cells sampled will be equal to self.n_cells_per_pseudobulk by mixupvae
khalilouardini marked this conversation as resolved.
Show resolved Hide resolved
# so change that to being equal to either all cells (if average_all_cells) or sc_per_pseudobulk
result = model.get_latent_representation(
adata_sampled, get_pseudobulk=True
)[0] # take first pseudobulk
else:
if SIGNATURE_TYPE == "pre_encoded":
pseudobulk = (
adata_sampled.layers[count_key].mean(axis=0).reshape(1, -1)
) # .astype(int).astype(numpy.float32)
adata_pseudobulk = create_anndata_pseudobulk(
adata_sampled, pseudobulk
)
result = model.get_latent_representation(adata_pseudobulk).reshape(
-1
adata_sampled = create_anndata_pseudobulk(
adata_sampled, pseudobulk
)
else:
raise ValueError(
"Only pre-encoded signatures are supported for now."
)
result = model.get_latent_representation(adata_sampled)
if SIGNATURE_TYPE == "pre_encoded":
result = result.reshape(-1)
elif SIGNATURE_TYPE == "post_inference":
result = result.mean(axis=0)
repeat_list.append(repeat)
representation_list.append(result)
cell_type_list.append(cell_type)
Expand Down
18 changes: 17 additions & 1 deletion benchmark_utils/plotting_utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import os
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd

from typing import Dict
from datetime import datetime
from loguru import logger

def plot_purified_deconv_results(deconv_results, only_fit_one_baseline, more_details=False, save=False, filename="test"):
"""Plot the deconv results from sanity check 1"""
Expand Down Expand Up @@ -106,7 +109,14 @@ def plot_deconv_lineplot(results: Dict[int, pd.DataFrame],
plt.show()

if save:
plt.savefig(f"/home/owkin/project/plots/{filename}.png", dpi=300)
path = f"/home/owkin/project/plots/{filename}.png"
if os.path.isfile(path):
new_path = f"/home/owkin/project/plots/{filename}_{datetime.now().strftime('%d_%m_%Y_%H_%M_%S')}.png"
logger.warning(f"{path} already exists. Saving file on this path instead: {new_path}")
path = new_path
plt.savefig(path, dpi=300)
logger.info(f"Plot saved to the following path: {path}")


def plot_metrics(model_history, train: bool = True, n_epochs: int = 100):
"""Plot the train or val metrics from training."""
Expand Down Expand Up @@ -241,6 +251,12 @@ def compare_tuning_results(

custom_palette = sns.color_palette("husl", n_colors=len(all_results[variable_tuned].unique()))
all_results["epoch"] = all_results.index
if (n_nan := all_results[variable_to_plot].isna().sum()) > 0:
print(
f"There are {n_nan} missing values in the variable to plot ({variable_to_plot})."
"Filling them with the next row values."
)
all_results[variable_to_plot] = all_results[variable_to_plot].fillna(method='bfill')
sns.set_theme(style="darkgrid")
sns.lineplot(x="epoch", y=variable_to_plot, hue=variable_tuned, ci="sd", data=all_results, err_style="bars", palette=custom_palette)
plt.show()
Loading