Skip to content

Commit

Permalink
Merge pull request #114 from DBinary/master
Browse files Browse the repository at this point in the history
Standardize and correct the code related to spatial analysis.
  • Loading branch information
Starlitnightly authored Aug 22, 2024
2 parents bd21ba7 + f812b97 commit 3c76d7d
Show file tree
Hide file tree
Showing 9 changed files with 797 additions and 291 deletions.
263 changes: 263 additions & 0 deletions omicverse/pl/_heatmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
from pandas.api.types import CategoricalDtype, is_numeric_dtype
import seaborn as sns
import scanpy as sc
from scanpy.plotting._anndata import _prepare_dataframe
import pandas as pd
from anndata import AnnData
from ..utils import plotset

pycomplexheatmap_install=False
Expand Down Expand Up @@ -299,6 +301,267 @@ def complexheatmap(adata,
plt.show()
return cm

def _version_tuple(version):
    """Return the leading numeric components of a version string as ints.

    Used so versions compare numerically, e.g. ``"1.10.0"`` -> ``(1, 10, 0)``.
    Parsing stops at the first component that does not start with a digit,
    so suffixes such as ``rc1`` or ``.post1`` are ignored.
    """
    parts = []
    for token in str(version).split("."):
        digits = ""
        for char in token:
            if not char.isdigit():
                break
            digits += char
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)


def marker_heatmap(
    adata: AnnData,
    marker_genes_dict: dict = None,
    groupby: str = None,
    color_map: str = "RdBu_r",
    use_raw: bool = True,
    standard_scale: str = "var",
    expression_cutoff: float = 0.0,
    bbox_to_anchor: tuple = (5, -0.5),
    figsize: tuple = (8,4),
    spines: bool = False,
    fontsize: int = 12,
    show_rownames: bool = True,
    show_colnames: bool = True,
    save_pathway: str = None,
    ax=None,
):
    """
    Draw a dot heatmap of marker-gene expression per cell group.

    Dot colour encodes the (optionally rescaled) mean expression of a gene in
    a group; dot size encodes the fraction of cells in the group whose
    expression exceeds ``expression_cutoff``. Plotting is delegated to
    PyComplexHeatmap's ``DotClustermapPlotter``.

    Parameters:
    ----------
    adata: AnnData object
        Annotated data matrix.
    marker_genes_dict: dict
        A dictionary mapping each cell type to its marker gene(s). A value may
        be a single gene name or an iterable of gene names. The keys are used
        to select rows and look up colours, so they must be categories of
        ``adata.obs[groupby]``.
    groupby: str
        The key in adata.obs that will be used for grouping the cells.
        The column is accessed through ``.cat``, so it must be categorical.
    color_map: str
        The color map to use for the value of heatmap.
    use_raw: bool
        Whether to use the raw data of AnnData object for plotting. If True
        but ``adata.raw`` is None, a warning is issued and ``adata.X`` is used.
    standard_scale: str
        ``"var"`` rescales every gene to [0, 1] across groups, ``"group"``
        rescales every group to [0, 1] across genes. Any other value
        (including None) leaves the mean expression unscaled.
    expression_cutoff: float
        A cell counts as expressing a gene when its value is strictly greater
        than this cutoff.
    bbox_to_anchor: tuple
        The position of the legend bbox (x, y) in axes coordinates.
    figsize: tuple
        The size of the plot figure in inches (width, height). Only used when
        ``ax`` is None.
    spines: bool
        Whether to show the spines of the plot.
    fontsize: int
        The font size of the text in the plot.
    show_rownames: bool
        Whether to show the row names in the heatmap.
    show_colnames: bool
        Whether to show the column names in the heatmap.
    save_pathway: str
        The file path for saving the plot. Nothing is saved when None.
    ax: matplotlib.axes.Axes
        A pre-existing axes object for plotting (optional).

    Returns:
    ----------
    (fig, ax), or None when ``marker_genes_dict`` is empty/None or ``groupby``
    is None (a hint is printed in that case).

    Raises:
    ----------
    ImportError
        If PyComplexHeatmap is not installed or is older than 1.7.5.

    Examples:
    ----------
    marker_heatmap(
        adata,
        marker_genes_dict,
        groupby='major_celltype',
        color_map="RdBu_r",
        use_raw=True,
        standard_scale="var",
        expression_cutoff=0.0,
        fontsize=12,
        bbox_to_anchor=(7, -0.5),
        figsize=(8,4),
        spines=False,
        show_rownames=True,
        show_colnames=True,
    )
    """

    # input check (kept as print + return so existing callers are not broken)
    if not marker_genes_dict:
        print("Please provide a dictionary containing the marker genes for each cell type.")
        return
    if groupby is None:
        print("Please provide a key in adata.obs for grouping the cells.")
        return

    # pycomplexheatmap version check
    install_hint = (
        'Please install PyComplexHeatmap with version >= 1.7.5: `pip install PyComplexHeatmap`.'
    )
    try:
        import PyComplexHeatmap as pch
        from PyComplexHeatmap import DotClustermapPlotter, HeatmapAnnotation, anno_simple
    except ImportError as err:
        raise ImportError(install_hint) from err
    print('PyComplexHeatmap have been install version:',pch.__version__)
    # Compare numerically: a plain string comparison would rank "1.10.0" below "1.7.5".
    if _version_tuple(pch.__version__) < (1, 7, 5):
        raise ImportError(install_hint)

    # Fall back to adata.X when raw data is requested but was never stored.
    if use_raw and adata.raw is None:
        import warnings
        warnings.warn("`use_raw=True` but `adata.raw` is None; using `adata.X` instead.")
        use_raw = False

    # Determine the color palette for different categories based on annotation data.
    group_categories = adata.obs[groupby].cat.categories
    colors_key = f"{groupby}_colors"
    if colors_key in adata.uns:
        palette = adata.uns[colors_key]
    elif len(group_categories) > 28:
        palette = sc.pl.palettes.default_102
    else:
        palette = sc.pl.palettes.zeileis_28
    type_color_all = dict(zip(group_categories, palette))

    # Normalise the marker dictionary once so that a bare gene name is treated
    # as a one-element list everywhere below (not iterated character by character).
    markers = {
        label: [genes] if isinstance(genes, str) else list(genes)
        for label, genes in marker_genes_dict.items()
    }
    gene_list = [gene for genes in markers.values() for gene in genes]

    # Prepare data for plotting using Scanpy's internal function.
    _, obs_tidy = _prepare_dataframe(
        adata,
        gene_list,
        groupby=groupby,
        use_raw=use_raw,
        log=False,
        num_categories=7,
        layer=None,
        gene_symbols=None,
    )

    # Dot size: fraction of cells per group above the expression cutoff.
    obs_bool = obs_tidy > expression_cutoff
    dot_size_df = (
        obs_bool.groupby(level=0, observed=True).sum()
        / obs_bool.groupby(level=0, observed=True).count()
    )

    # Dot colour: mean expression per group, optionally rescaled to [0, 1].
    dot_color_df = obs_tidy.groupby(level=0, observed=True).mean()
    if standard_scale == "group":
        dot_color_df = dot_color_df.sub(dot_color_df.min(1), axis=0)
        dot_color_df = dot_color_df.div(dot_color_df.max(1), axis=0).fillna(0)
    elif standard_scale == "var":
        dot_color_df -= dot_color_df.min(0)
        dot_color_df = (dot_color_df / dot_color_df.max(0)).fillna(0)
    # any other value: leave the mean expression unscaled

    # Row annotation table: one row per cell type, in marker-dictionary order.
    df_row = dot_color_df.index.to_frame()
    df_row['Celltype'] = dot_color_df.index
    df_row.set_index('Celltype', inplace=True)
    df_row.columns = ['Celltype_name']
    df_row = df_row.loc[list(markers.keys()), :]

    # Column annotation table: one row per gene with the cell type it marks.
    df_col = pd.DataFrame(
        {
            'Gene_name': gene_list,
            'Celltype_name': [label for label, genes in markers.items() for _ in genes],
        },
        index=gene_list,
    )

    # Create a melted DataFrame for color and size data.
    color_label = 'Mean\nexpression\nin group'
    size_label = 'Fraction\nof cells\nin group'
    color_df = pd.melt(dot_color_df.reset_index(), id_vars=groupby, var_name='gene', value_name=color_label)
    color_df[groupby] = color_df[groupby].astype(str)
    color_df.index = color_df[groupby] + '_' + color_df['gene']
    size_df = pd.melt(dot_size_df.reset_index(), id_vars=groupby, var_name='gene', value_name=size_label)
    size_df[groupby] = size_df[groupby].astype(str)
    size_df.index = size_df[groupby] + '_' + size_df['gene']
    color_df[size_label] = size_df.loc[color_df.index.tolist(), size_label]

    # Each gene column gets the colour of the cell type it marks.
    gene_color = [
        type_color_all[celltype]
        for celltype in df_row.Celltype_name
        for _ in markers[celltype]
    ]

    # plot the complex heatmap
    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)
    else:
        # Bind `fig` so the final return works for caller-supplied axes, and make
        # those axes current. NOTE(review): this assumes PyComplexHeatmap draws on
        # the current axes -- confirm against its documentation.
        fig = ax.figure
        plt.sca(ax)

    row_ha = HeatmapAnnotation(
        TARGET=anno_simple(
            df_row.Celltype_name,
            colors=[type_color_all[i] for i in df_row.Celltype_name],
            add_text=False,
            text_kws={'color': 'black', 'rotation': 0,'fontsize':fontsize},
            legend=False  # set to True to show the legend for the rows
        ),
        legend_gap=7,
        axis=0,
        verbose=0,
        #label_side='left',
        label_kws={'rotation': 90, 'horizontalalignment': 'right','fontsize':0},
    )

    col_ha = HeatmapAnnotation(
        TARGET=anno_simple(
            df_col.Gene_name,
            colors=gene_color,
            add_text=False,
            text_kws={'color': 'black', 'rotation': 0,'fontsize':fontsize},
            legend=False  # set to True to show the legend for the columns
        ),
        verbose=0,
        label_kws={'horizontalalignment': 'right','fontsize':0},
        legend_kws={'ncols': 1},  # use a single column for the legend
        legend=False,
        legend_hpad=7,
        legend_vpad=5,
        axis=1,
    )

    cm = DotClustermapPlotter(color_df,y=groupby,x='gene',value=color_label,
                c=color_label,s=size_label,cmap=color_map,
                vmin=0,
                hue=groupby,
                top_annotation=col_ha,left_annotation=row_ha,
                row_dendrogram=False,col_dendrogram=False,
                col_split_order=list(df_col.Celltype_name.unique()),
                col_split=df_col.Celltype_name,col_split_gap=1,
                # row_split=df_row.Celltype_name,row_split_gap=1,
                x_order=df_col.Gene_name.unique(),y_order=df_col.Celltype_name.unique(),
                row_cluster=False,col_cluster=False,
                show_rownames=show_rownames,show_colnames=show_colnames,
                col_names_side='left',spines=spines,grid='minor',
                legend=True,)

    # Adjust grid settings
    cm.ax_heatmap.grid(which='minor', color='gray', linestyle='--', alpha=0.5)
    cm.ax_heatmap.grid(which='major', color='black', linestyle='-', linewidth=0.5)
    cm.cmap_legend_kws={'ncols': 1}
    plt.grid(False)
    plt.tight_layout()  # adjust the layout so that all components fit

    # legend plot: one coloured line per cell type
    handles = [plt.Line2D([0], [0], color=color, lw=4) for color in type_color_all.values()]
    labels = list(type_color_all.keys())
    # Add a legend to the right of the existing image
    legend_kws={'fontsize':fontsize,'bbox_to_anchor':bbox_to_anchor,'loc':'center left',}
    plt.legend(handles, labels,
               borderaxespad=1, handletextpad=0.5, labelspacing=0.2,**legend_kws)

    if save_pathway is not None:
        plt.savefig(save_pathway, dpi=300, bbox_inches='tight')

    plt.tight_layout()
    plt.show()

    return fig,ax


def global_imports(modulename,shortname = None, asfunction = False):
if shortname is None:
shortname = modulename
Expand Down
Loading

0 comments on commit 3c76d7d

Please sign in to comment.