Skip to content

Commit

Permalink
adapt new functions
Browse files Browse the repository at this point in the history
  • Loading branch information
gaddams committed Nov 3, 2021
1 parent da6b5d8 commit de0d884
Show file tree
Hide file tree
Showing 5 changed files with 1,422 additions and 1,986 deletions.
2 changes: 1 addition & 1 deletion environment.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
dependencies:
- python=3.8.5
- python>=3.8.5
- pip=20.2.2
- pytorch=1.4.0
- scipy=1.5.2
Expand Down
9 changes: 3 additions & 6 deletions tangram/mapping_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,21 +78,18 @@ def pp_adatas(adata_sc, adata_sp, genes=None):
)

# Calculate uniform density prior as 1/number_of_spots
rna_count_per_spot = adata_sp.X.sum(axis=1)
adata_sp.obs["uniform_density"] = np.ones(adata_sp.X.shape[0]) / adata_sp.X.shape[0]
logging.info(
f"uniform based density prior is calculated and saved in `obs``uniform_density` of the spatial Anndata."
)

# Calculate rna_count_based density prior as % of rna molecule count
rna_count_per_spot = adata_sp.X.sum(axis=1)
adata_sp.obs["rna_count_based_density"] = rna_count_per_spot / np.sum(
rna_count_per_spot
)
rna_count_per_spot = np.array(adata_sp.X.sum(axis=1)).squeeze()
adata_sp.obs["rna_count_based_density"] = rna_count_per_spot / np.sum(rna_count_per_spot)
logging.info(
f"rna count based density prior is calculated and saved in `obs``rna_count_based_density` of the spatial Anndata."
)


def adata_to_cluster_expression(adata, cluster_label, scale=True, add_density=True):
"""
Expand Down
59 changes: 49 additions & 10 deletions tangram/plot_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,20 +172,41 @@ def construct_obs_plot(df_plot, adata, perc=0, suffix=None):
adata.obs = pd.concat([adata.obs, df_plot], axis=1)


def plot_cell_annotation_sc(adata_sp, annotation_list, perc=0):

def plot_cell_annotation_sc(
adata_sp,
annotation_list,
x="x",
y="y",
spot_size=None,
scale_factor=None,
perc=0,
ax=None
):

# remove previous df_plot in obs
adata_sp.obs.drop(annotation_list, inplace=True, errors="ignore", axis=1)

# construct df_plot
df = adata_sp.obsm["tangram_ct_pred"][annotation_list]
construct_obs_plot(df, adata_sp, perc=perc)


#non visium data
if 'spatial' not in adata_sp.obsm.keys():
#add spatial coordinates to obsm of spatial data
coords = [[x,y] for x,y in zip(adata_sp.obs[x].values,adata_sp.obs[y].values)]
adata_sp.obsm['spatial'] = np.array(coords)

if 'spatial' not in adata_sp.uns.keys() and spot_size == None and scale_factor == None:
raise ValueError("Spot Size and Scale Factor cannot be None when ad_sp.uns['spatial'] does not exist")

#REVIEW
if 'spatial' in adata_sp.uns.keys() and spot_size != None and scale_factor != None:
raise ValueError("Spot Size and Scale Factor should be None when ad_sp.uns['spatial'] exists")

sc.pl.spatial(
adata_sp, color=annotation_list, cmap="viridis", show=False, frameon=False,
adata_sp, color=annotation_list, cmap="viridis", show=False, frameon=False, spot_size=spot_size, scale_factor=scale_factor, ax=ax
)

# remove df_plot in obs
adata_sp.obs.drop(annotation_list, inplace=True, errors="ignore", axis=1)


Expand Down Expand Up @@ -289,7 +310,16 @@ def plot_cell_annotation(
fig.suptitle(annotation)


def plot_genes_sc(genes, adata_measured, adata_predicted, cmap="inferno", perc=0):
def plot_genes_sc(
genes,
adata_measured,
adata_predicted,
spot_size=None,
scale_factor=None,
cmap="inferno",
perc=0,
return_figure=False
):

# remove df_plot in obs
adata_measured.obs.drop(
Expand Down Expand Up @@ -350,11 +380,17 @@ def plot_genes_sc(genes, adata_measured, adata_predicted, cmap="inferno", perc=0

fig = plt.figure(figsize=(7, len(genes) * 3.5))
gs = GridSpec(len(genes), 2, figure=fig)

#non visium data
if ("spatial" not in adata_measured.uns.keys()) and (spot_size==None and scale_factor==None):
raise ValueError("Spot Size and Scale Factor cannot be None when ad_sp.uns['spatial'] does not exist")

for ix, gene in enumerate(genes):

ax_m = fig.add_subplot(gs[ix, 0])
sc.pl.spatial(
adata_measured,
spot_size=spot_size,
scale_factor=scale_factor,
color=["{} (measured)".format(gene)],
frameon=False,
ax=ax_m,
Expand All @@ -364,13 +400,15 @@ def plot_genes_sc(genes, adata_measured, adata_predicted, cmap="inferno", perc=0
ax_p = fig.add_subplot(gs[ix, 1])
sc.pl.spatial(
adata_predicted,
spot_size=spot_size,
scale_factor=scale_factor,
color=["{} (predicted)".format(gene)],
frameon=False,
ax=ax_p,
show=False,
cmap=cmap,
)

# sc.pl.spatial(adata_measured, color=['{} (measured)'.format(gene) for gene in genes], frameon=False)
# sc.pl.spatial(adata_predicted, color=['{} (predicted)'.format(gene) for gene in genes], frameon=False)

Expand All @@ -387,6 +425,8 @@ def plot_genes_sc(genes, adata_measured, adata_predicted, cmap="inferno", perc=0
errors="ignore",
axis=1,
)
if return_figure==True:
return fig


def plot_genes(
Expand Down Expand Up @@ -631,8 +671,7 @@ def plot_auc(df_all_genes, test_genes=None):
textstr = 'auc_score={}'.format(np.round(metric_dict['auc_score'], 3))
props = dict(boxstyle='round', facecolor='wheat', alpha=0.3)
# place a text box in upper left in axes coords
plt.text(0.03, 0.1, textstr, fontsize=11,
verticalalignment='top', bbox=props);
plt.text(0.03, 0.1, textstr, fontsize=11, verticalalignment='top', bbox=props);


# Colors used in the manuscript for deterministic assignment.
Expand Down
Loading

0 comments on commit de0d884

Please sign in to comment.