Skip to content

Commit

Permalink
Added embedding_atlas to speed visualized 1M cell atlas
Browse files Browse the repository at this point in the history
  • Loading branch information
Starlitnightly committed Nov 29, 2024
1 parent 6a1599c commit 60a2750
Show file tree
Hide file tree
Showing 6 changed files with 161 additions and 9 deletions.
3 changes: 2 additions & 1 deletion omicverse/externel/BINARY/Model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from torch import Tensor
import torch.nn.functional as F
from torch.nn import Parameter
from torch_sparse import SparseTensor, set_diag

from torch_geometric.nn.dense.linear import Linear
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import remove_self_loops, add_self_loops, softmax
Expand Down Expand Up @@ -98,6 +98,7 @@ def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
# type: (Union[Tensor, OptPairTensor], SparseTensor, Size, NoneType) -> Tensor # noqa
# type: (Union[Tensor, OptPairTensor], Tensor, Size, bool) -> Tuple[Tensor, Tuple[Tensor, Tensor]] # noqa
# type: (Union[Tensor, OptPairTensor], SparseTensor, Size, bool) -> Tuple[Tensor, SparseTensor] # noqa
from torch_sparse import SparseTensor, set_diag
r"""
Args:
return_attention_weights (bool, optional): If set to :obj:`True`,
Expand Down
1 change: 0 additions & 1 deletion omicverse/externel/BINARY/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from torch import Tensor
import torch.nn.functional as F
from torch.nn import Parameter
from torch_sparse import SparseTensor, set_diag
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import remove_self_loops, add_self_loops, softmax
Expand Down
3 changes: 2 additions & 1 deletion omicverse/pl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,5 @@
from ._bulk import *
from ._space import *
from ._cpdb import *
from ._flowsig import *
from ._flowsig import *
from ._embedding import *
149 changes: 149 additions & 0 deletions omicverse/pl/_embedding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import scanpy as sc

def embedding_atlas(adata,basis,color,
title=None,figsize=(4,4),ax=None,cmap='RdBu',
legend_loc = 'right margin',frameon='small',
fontsize=12):
import scanpy as sc
import pandas as pd
import datashader as ds
import datashader.transfer_functions as tf
from scipy.sparse import issparse
from bokeh.palettes import RdBu9
import bokeh
# 创建一个 Canvas 对象
cvs = ds.Canvas(plot_width=800, plot_height=800)


embedding = adata.obsm[basis]
# 如果你有一个感兴趣的分类标签,比如细胞类型

# 将数据转换为 DataFrame
df = pd.DataFrame(embedding, columns=['x', 'y'])

if color in adata.obs.columns:
labels = adata.obs[color].tolist() # 假设'cell_type'是一个列名
elif color in adata.var_names:
X=adata[:,color].X
if issparse(X):
labels=X.toarray().reshape(-1)
else:
labels=X.reshape(-1)
elif (not adata.raw is None) and (color in adata.raw.var_names):
X=adata.raw[:,color].X
if issparse(X):
labels=X.toarray().reshape(-1)
else:
labels=X.reshape(-1)


df['label'] = labels
#return labels
#print(labels[0],type(labels[0]))
if type(labels[0]) is str:
df['label']=df['label'].astype('category')
# 聚合数据
agg = cvs.points(df, 'x', 'y',ds.count_cat('label'),
)
legend_tag=True
color_key = dict(zip(adata.obs[color].cat.categories,
adata.uns[f'{color}_colors']))


# 使用色彩映射
img = tf.shade(tf.spread(agg,px=0),color_key=[color_key[i] for i in df['label'].cat.categories],
how='eq_hist')
elif (type(labels[0]) is int) or (type(labels[0]) is float) or (type(labels[0]) is np.float32)\
or (type(labels[0]) is np.float64) or (type(labels[0]) is np.int):
# 聚合数据
agg = cvs.points(df, 'x', 'y',ds.mean('label'),
)
legend_tag=False
if cmap in bokeh.palettes.all_palettes.keys():
num=list(bokeh.palettes.all_palettes[cmap].keys())[-1]
img = tf.shade(agg,cmap=bokeh.palettes.all_palettes[cmap][num],
)
else:
img = tf.shade(agg,cmap=cmap,
)
else:
print('Unrecognized label type')
return None



# 假设 img 是 Datashader 渲染的图像
if ax is None:
fig, ax = plt.subplots(figsize=figsize)
else:
fig=ax.figure

# 假设 img 是一个 NumPy 数组或类似的对象,这里使用 img 的占位符
# img = np.random.rand(100, 100) # 示例数据
ax.imshow(img.to_pil(), aspect='auto')


# 自定义格式化函数以显示坐标
def format_coord(x, y):
return f"x={x:.2f}, y={y:.2f}"

ax.format_coord = format_coord

if legend_tag==True:
# 手动创建图例
unique_labels = adata.obs[color].cat.categories

# 创建图例项
for label in unique_labels:
ax.scatter([], [], c=color_key[label], label=label)

if legend_loc == "right margin":
ax.legend(
frameon=False,
loc="center left",
bbox_to_anchor=(1, 0.5),
ncol=(1 if len(unique_labels) <= 14 else 2 if len(unique_labels) <= 30 else 3),
fontsize=fontsize-1,
)
if frameon==False:
ax.axis('off')
elif frameon=='small':
ax.axis('on')
ax.set_xticks([])
ax.set_yticks([])
ax.spines['left'].set_visible(True)
ax.spines['bottom'].set_visible(True)
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.spines['bottom'].set_bounds(0,150)
ax.spines['left'].set_bounds(650,800)
ax.set_xlabel(f'{basis}1',loc='left',fontsize=fontsize)
ax.set_ylabel(f'{basis}2',loc='bottom',fontsize=fontsize)

else:
ax.axis('on')
ax.set_xticks([])
ax.set_yticks([])
ax.spines['left'].set_visible(True)
ax.spines['bottom'].set_visible(True)
ax.spines['top'].set_visible(True)
ax.spines['right'].set_visible(True)
ax.set_xlabel(f'{basis}1',loc='center',fontsize=fontsize)
ax.set_ylabel(f'{basis}2',loc='center',fontsize=fontsize)


# 调整坐标轴线的粗细
line_width = 1.2 # 设置线宽
ax.spines['left'].set_linewidth(line_width)
ax.spines['bottom'].set_linewidth(line_width)


if title is None:
title=color
ax.set_title(title,fontsize=fontsize+1)

return ax
4 changes: 3 additions & 1 deletion omicverse/pl/_single.py
Original file line number Diff line number Diff line change
Expand Up @@ -1138,7 +1138,9 @@ def violin_box(adata, keys, groupby, ax=None, figsize=(4,4), show=True, max_stri
ax.set_xlim(xlim)
ax.set_ylim(ylim)
#ax.set_xticklabels(ax.get_xticklabels(),rotation=90)
ax.get_legend().remove()
#remove legend
if ax.get_legend() is not None:
ax.get_legend().remove()
#ax.legend().set_visible(False)
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
Expand Down
10 changes: 5 additions & 5 deletions omicverse/space/_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,8 +317,8 @@ def clusters(adata,
return adata

def merge_cluster(adata,groupby='mclust',use_rep='STAGATE',
threshold=0.05,plot=True):
sc.tl.dendrogram(adata,groupby=groupby,use_rep=use_rep,)
threshold=0.05,plot=True,start_idx=0,**kwargs):
sc.tl.dendrogram(adata,groupby=groupby,use_rep=use_rep)
import numpy as np
from scipy.cluster.hierarchy import fcluster

Expand All @@ -330,14 +330,14 @@ def merge_cluster(adata,groupby='mclust',use_rep='STAGATE',

# 使用fcluster来合并类别
clusters = fcluster(linkage_matrix, threshold, criterion='distance')

# 创建字典
cluster_dict = {}
for idx, cluster_id in enumerate(clusters):
key = f'c{cluster_id}'
if key not in cluster_dict:
cluster_dict[key] = []
cluster_dict[key].append(idx+1)
cluster_dict[key].append(idx+start_idx)

reversed_dict = {}
for key, values in cluster_dict.items():
Expand All @@ -349,6 +349,6 @@ def merge_cluster(adata,groupby='mclust',use_rep='STAGATE',
adata.obs[f'{groupby}_tree']=adata.obs[groupby].map(reversed_dict)
print(f'The merged cluster information is stored in adata.obs["{groupby}_tree"].')
if plot:
ax=sc.pl.dendrogram(adata,groupby=groupby,show=False)
ax=sc.pl.dendrogram(adata,groupby=groupby,show=False,**kwargs)
ax.plot((ax.get_xticks().min(),ax.get_xticks().max()),(threshold,threshold))
return reversed_dict

0 comments on commit 60a2750

Please sign in to comment.