Skip to content

Commit

Permalink
Fixed the error of get_results and get_results_rfc in cNMF modu…
Browse files Browse the repository at this point in the history
…le. (#143) (#139)
  • Loading branch information
Starlitnightly committed Sep 16, 2024
1 parent 4f0838a commit 32f8ffa
Show file tree
Hide file tree
Showing 5 changed files with 92 additions and 9 deletions.
13 changes: 9 additions & 4 deletions omicverse/externel/STT/pl/_plot_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
def plot_tensor_single(adata, adata_aggr = None, state = 'joint',
attractor = None, basis = 'umap', color ='attractor',
color_map = None, size = 20, alpha = 0.5, ax = None,
show = None, filter_cells = False, member_thresh = 0.05, density =2):
show = None, filter_cells = False, member_thresh = 0.05, density =2,
n_jobs = -1):
"""
Function to plot a single tensor graph with assgined components
Expand Down Expand Up @@ -66,18 +67,22 @@ def plot_tensor_single(adata, adata_aggr = None, state = 'joint',

if state == 'spliced':
adata.layers['vs'] = velo[:,gene_select,1]
scv.tl.velocity_graph(adata, vkey = 'vs', xkey = 'Ms',n_jobs = -1)
scv.tl.velocity_graph(adata, vkey = 'vs', xkey = 'Ms',n_jobs = n_jobs)
scv.pl.velocity_embedding_stream(adata, vkey = 'vs', basis=basis, color=color, title = title+','+'Spliced',color_map = color_map, size = size, alpha = alpha, ax = ax, show = show)
if state == 'unspliced':
adata.layers['vu'] = velo[:,gene_select,0]
scv.tl.velocity_graph(adata, vkey = 'vu', xkey = 'Mu',n_jobs = -1)
scv.tl.velocity_graph(adata, vkey = 'vu', xkey = 'Mu',n_jobs = n_jobs)
scv.pl.velocity_embedding_stream(adata, vkey = 'vu',basis=basis, color=color, title = title+','+'Unspliced',color_map = color_map, size = size, alpha = alpha, ax = ax, show = show)
if state == 'joint':
print("check that the input includes aggregated object")
#adata_aggr.layers['vj'] = np.concatenate((velo[:,gene_select,0],velo[:,gene_select,1]),axis = 1)
scv.tl.velocity_graph(adata_aggr, vkey = 'vj', xkey = 'Ms',n_jobs = -1)
scv.tl.velocity_graph(adata_aggr, vkey = 'vj', xkey = 'Ms',n_jobs = n_jobs)
scv.pl.velocity_embedding_stream(adata_aggr, vkey = 'vj',basis=basis, color=color, title = title+','+'Joint',color_map = color_map, size = size, alpha = alpha, ax = ax, show = show, density =density)

del adata_copy
del adata_aggr_copy
import gc
gc.collect()


def plot_tensor(adata, adata_aggr, list_state =['joint','spliced','unspliced'], list_attractor ='all', basis = 'umap',figsize = (8,8),hspace = 0.2,wspace = 0.2, color_map = None,size = 20,alpha = 0.5, filter_cells = False, member_thresh = 0.05, density =2):
Expand Down
2 changes: 1 addition & 1 deletion omicverse/externel/STT/tl/_construct_landscape.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def construct_landscape(sc_object,thresh_cal_cov = 0.3, scale_axis = 1.0, scale_
mu_hat = sc_object.uns['da_out']['mu_hat']
rho = sc_object.obsm['rho']
projection = sc_object.obsm[coord_key][:,0:2]
p_hat=adata.uns['da_out']['P_hat']
p_hat=sc_object.uns['da_out']['P_hat']


labels = np.argmax(rho,axis = 1)
Expand Down
62 changes: 62 additions & 0 deletions omicverse/externel/cnmf/cnmf.py
Original file line number Diff line number Diff line change
Expand Up @@ -1079,6 +1079,68 @@ def get_results_rfc(self,adata,result_dict,use_rep='STAGATE',cNMF_threshold=0.5)
print("Single Tree:",score_c)
print("Random Forest:",score_r)

adata.obs['cNMF_cluster_rfc']=[str(i) for i in rfc.predict(adata.obsm[use_rep])]
adata.obs['cNMF_cluster_clf']=[str(i) for i in clf.predict(adata.obsm[use_rep])]
print('cNMF_cluster_rfc is added to adata.obs')
print('cNMF_cluster_clf is added to adata.obs')

def get_results(self, adata, result_dict):
    """Attach cNMF usages and gene scores to ``adata`` in place.

    Parameters
    ----------
    adata
        Object exposing ``obs`` and ``var`` as pandas DataFrames; both
        attributes are replaced by merged copies.
    result_dict
        Mapping with two DataFrames: ``'usage_norm'`` (indexed like
        ``adata.obs``) and ``'gep_scores'`` (indexed by the labels of
        ``adata.var``).

    Notes
    -----
    * ``adata.obs`` gains the ``usage_norm`` columns plus ``'cNMF_cluster'``,
      the name of the usage column holding each row's maximum.
    * ``adata.var`` gains the ``gep_scores`` columns.
    * The method may be called repeatedly: columns written by an earlier
      call are dropped before merging, so no ``_x``/``_y`` suffixed
      duplicates are produced in either table.
    * ``gep_scores.loc[adata.var.index]`` raises ``KeyError`` when a label of
      ``adata.var`` is missing from ``gep_scores``.
    """
    import pandas as pd

    usage_norm = result_dict['usage_norm']
    gep_scores = result_dict['gep_scores']

    if usage_norm.columns[0] in adata.obs.columns:
        # Usages were merged before: remove every column whose name starts
        # with 'cNMF' so the merge below starts from a clean table.
        adata.obs = adata.obs.loc[:, ~adata.obs.columns.str.startswith('cNMF')]
    adata.obs = pd.merge(left=adata.obs, right=usage_norm,
                         how='left', left_index=True, right_index=True)

    # Same guard for the gene scores: overlapping column names would be
    # renamed with '_x'/'_y' suffixes by the merge on a repeated call.
    stale = adata.var.columns.isin(gep_scores.columns)
    if stale.any():
        adata.var = adata.var.loc[:, ~stale]
    adata.var = pd.merge(left=adata.var, right=gep_scores.loc[adata.var.index],
                         how='left', left_index=True, right_index=True)

    # Label each row with the usage column that holds its maximum value.
    adata.obs['cNMF_cluster'] = adata.obs[usage_norm.columns].idxmax(axis=1)
    print('cNMF_cluster is added to adata.obs')
    print('gene scores are added to adata.var')

def get_results_rfc(self,adata,result_dict,use_rep='STAGATE',cNMF_threshold=0.5):
import pandas as pd
if result_dict['usage_norm'].columns[0] in adata.obs.columns:
#remove the columns if they already exist
#remove columns name starts with 'cNMF'
adata.obs = adata.obs.loc[:,~adata.obs.columns.str.startswith('cNMF')]
adata.obs = pd.merge(left=adata.obs, right=result_dict['usage_norm'],
how='left', left_index=True, right_index=True)
adata.var = pd.merge(left=adata.var,right=result_dict['gep_scores'].loc[adata.var.index],
how='left', left_index=True, right_index=True)

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

new_array = []
class_array = []
for i in range(1, result_dict['usage_norm'].shape[1] + 1):
data = adata[adata.obs[f'cNMF_{i}'] > cNMF_threshold].obsm[use_rep].toarray()
new_array.append(data)
class_array.append(np.full(data.shape[0], i))

new_array = np.concatenate(new_array, axis=0)
class_array = np.concatenate(class_array)

Xtrain, Xtest, Ytrain, Ytest = train_test_split(new_array,class_array,test_size=0.3)
clf = DecisionTreeClassifier(random_state=0)
rfc = RandomForestClassifier(random_state=0)
clf = clf.fit(Xtrain,Ytrain)
rfc = rfc.fit(Xtrain,Ytrain)
#查看模型效果
score_c = clf.score(Xtest,Ytest)
score_r = rfc.score(Xtest,Ytest)
#打印最后结果
print("Single Tree:",score_c)
print("Random Forest:",score_r)

adata.obs['cNMF_cluster_rfc']=[str(i) for i in rfc.predict(adata.obsm[use_rep])]
adata.obs['cNMF_cluster_clf']=[str(i) for i in clf.predict(adata.obsm[use_rep])]
print('cNMF_cluster_rfc is added to adata.obs')
Expand Down
9 changes: 5 additions & 4 deletions omicverse/pl/_single.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def embedding(
def cellproportion(adata:AnnData,celltype_clusters:str,groupby:str,
groupby_li=None,figsize:tuple=(4,6),
ticks_fontsize:int=12,labels_fontsize:int=12,ax=None,
legend:bool=False):
legend:bool=False,legend_awargs=None):
"""
Plot cell proportion of each cell type in each visual cluster.
Expand Down Expand Up @@ -192,7 +192,7 @@ def cellproportion(adata:AnnData,celltype_clusters:str,groupby:str,
bottoms+=test1['value'].values
n+=1
if legend!=False:
plt.legend(bbox_to_anchor=(1.05, -0.05), loc=3, borderaxespad=0,fontsize=10)
plt.legend(bbox_to_anchor=(1.05, -0.05), loc=3, borderaxespad=0,fontsize=10,**legend_awargs)

plt.grid(False)

Expand Down Expand Up @@ -955,7 +955,7 @@ def plot_boxplots( # pragma: no cover
def cellstackarea(adata,celltype_clusters:str,groupby:str,
groupby_li=None,figsize:tuple=(4,6),
ticks_fontsize:int=12,labels_fontsize:int=12,ax=None,
legend:bool=False):
legend:bool=False,legend_awargs=None):
"""
Plot the cell type percentage in each groupby category
Expand Down Expand Up @@ -1006,7 +1006,8 @@ def cellstackarea(adata,celltype_clusters:str,groupby:str,
bottom += pivot_df[cell_type]

if legend!=False:
plt.legend(bbox_to_anchor=(1.05, -0.05), loc=3, borderaxespad=0,fontsize=10)
plt.legend(bbox_to_anchor=(1.05, -0.05), loc=3, borderaxespad=0,
fontsize=labels_fontsize,**legend_awargs)

plt.grid(False)

Expand Down
15 changes: 15 additions & 0 deletions omicverse_guide/docs/Release_notes.md
Original file line number Diff line number Diff line change
Expand Up @@ -453,3 +453,18 @@ Support Raw Windows platform
- Fixed type error of `float128` #134


## v 1.6.7

### Space Module

- Added the `n_jobs` argument to `externel.STT.pl.plot_tensor_single` to set the number of parallel jobs used when computing the velocity graph
- Fixed an error in `externel.STT.tl.construct_landscape` (an undefined `adata` was referenced instead of `sc_object`)


### Pl Module

- Added `legend_awargs` to pass extra keyword arguments to the legend in `pl.cellstackarea` and `pl.cellproportion`

### Single Module

- Fixed the error of `get_results` and `get_results_rfc` in `cNMF` module. (#143) (#139)

0 comments on commit 32f8ffa

Please sign in to comment.