Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement and debug DDRTree based methods #564

Merged
merged 58 commits into from
Oct 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
58 commits
Select commit Hold shift + click to select a range
403d784
add time_series to init
Sichao25 Jul 12, 2023
2b05718
debug directed_pg data loading
Sichao25 Jul 12, 2023
dd71e50
debug directed_pg by changing DDRTree return
Sichao25 Jul 12, 2023
0e7d5b5
optimize directed_pg calculation
Sichao25 Jul 12, 2023
6e76908
debug construct_velocity_tree_py
Sichao25 Jul 13, 2023
ee0df98
create _find_nearest_vertex func
Sichao25 Jul 13, 2023
9e127c3
debug DDRTree input shape
Sichao25 Jul 14, 2023
a70e447
create get DDRTree order func
Sichao25 Jul 14, 2023
4460a7c
create project2MST func
Sichao25 Jul 14, 2023
c33eae1
create select_root_cell func
Sichao25 Jul 14, 2023
337c33f
create order_cells func
Sichao25 Jul 17, 2023
39ac0fd
debug select_root_cell
Sichao25 Jul 17, 2023
3e51577
debug get_order_from_DDRTree
Sichao25 Jul 17, 2023
f6224b7
debug project2MST and order cells
Sichao25 Jul 17, 2023
4677cb1
debug select_root_cell w given state
Sichao25 Jul 17, 2023
26778e8
minor code style fix
Sichao25 Jul 17, 2023
5715e98
update tool init file
Sichao25 Jul 17, 2023
8c3df55
add type hint
Sichao25 Jul 17, 2023
dd414d5
add more args for order_cells
Sichao25 Jul 17, 2023
4d09407
docstr for projection method
Sichao25 Jul 18, 2023
1e4d75b
create docstr for all order cell funcs
Sichao25 Jul 18, 2023
e1098d5
update error message
Sichao25 Jul 18, 2023
857f5c3
add logger
Sichao25 Jul 18, 2023
c1b6d30
create transition matrix computation func
Sichao25 Jul 19, 2023
a7276fe
debug pseudotime
Sichao25 Jul 19, 2023
89b1dc8
create construct_velocity_tree func
Sichao25 Jul 19, 2023
e587c06
use segment instead of cell order
Sichao25 Jul 19, 2023
d07ed41
optimize construct_velocity_tree
Sichao25 Jul 19, 2023
8aca5e5
optimize transition matrix calculation
Sichao25 Jul 19, 2023
e6982a8
polish coding style
Sichao25 Jul 19, 2023
6d655f8
create docstr for construct_velocity_tree
Sichao25 Jul 19, 2023
940e6ad
remove unused helper and add docstr for deprecated func
Sichao25 Jul 19, 2023
6563409
remove deprecated pseudotime function
Sichao25 Jul 20, 2023
0a0051e
add necessary blank line
Sichao25 Jul 20, 2023
8b6ad9f
debug select_root_cell
Sichao25 Sep 7, 2023
3cfea43
add init cell types to order_cells
Sichao25 Sep 7, 2023
343f116
update docstr for order_cells
Sichao25 Sep 7, 2023
8b5ef09
create plot_dim_reduced_direct_graph proto
Sichao25 Sep 7, 2023
53c8489
add piechart to pseudotime plot
Sichao25 Sep 7, 2023
9270ce2
add draw_node in pseudotime plot
Sichao25 Sep 11, 2023
63d9a53
debug piechart distribution
Sichao25 Sep 11, 2023
311613e
implement get_color_map utils
Sichao25 Sep 15, 2023
ee2a838
optimize colormap
Sichao25 Sep 15, 2023
1cb4723
create docstr for pseudotime plot
Sichao25 Sep 15, 2023
03ca4b3
Merge branch 'master' into tool
Sichao25 Sep 29, 2023
6de0a62
optimize the format of directed graph plot
Sichao25 Oct 4, 2023
a89a85f
color the node without pievhart
Sichao25 Oct 17, 2023
af64f12
debug the comparison
Sichao25 Oct 17, 2023
2f03915
debug velocity tree
Sichao25 Oct 17, 2023
ac3600b
debug tree segmentation
Sichao25 Oct 18, 2023
0ff445f
debug center transition matrix
Sichao25 Oct 18, 2023
2d0a2f4
debug direction assignment
Sichao25 Oct 19, 2023
3ef6f8d
Merge branch 'master' into tool
Sichao25 Oct 19, 2023
5f1a251
debug plotting directed velocity tree
Sichao25 Oct 19, 2023
fac4788
rename funcs and debug cell orders
Sichao25 Oct 19, 2023
6a2cc7a
deal with none node
Sichao25 Oct 20, 2023
d91d0bf
add variance_scale as parameters
Sichao25 Oct 20, 2023
35b9a6e
debug piechart arrows
Sichao25 Oct 20, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions dynamo/plot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
show_fraction,
variance_explained,
)
from .pseudotime import plot_dim_reduced_direct_graph
from .scatters import scatters
from .scPotential import show_landscape
from .sctransform import sctransform_plot_fit, plot_residual_var
Expand Down Expand Up @@ -153,4 +154,5 @@
"hessian",
"sctransform_plot_fit",
"plot_residual_var",
"plot_dim_reduced_direct_graph",
]
202 changes: 200 additions & 2 deletions dynamo/plot/pseudotime.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,213 @@
from typing import Any, Dict, Tuple
import math
from typing import Any, Dict, List, Optional, Tuple, Union

try:
from typing import Literal
except ImportError:
from typing_extensions import Literal

import pandas as pd
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
from anndata import AnnData
from scipy.sparse import csr_matrix

from ..tools.utils import update_dict
from .utils import save_fig
from .utils import get_color_map_from_labels, save_fig


def _calculate_cells_mapping(
adata: AnnData,
group_key: str,
cell_proj_closest_vertex: Optional[np.ndarray] = None,
) -> Tuple[np.ndarray, np.ndarray, Dict]:
"""Calculate the distribution of cells in each node.

Args:
adata: the anndata object.
group_key: the key to locate the groups of each cell in adata.
cell_proj_closest_vertex: the mapping from each cell to the corresponding node.

Returns:
The size of each node, the percentage of each group in every node, and the color mapping of each group.
"""
cells_mapping_size = np.bincount(cell_proj_closest_vertex)
centroids_index = range(len(cells_mapping_size))

cell_type_info = pd.DataFrame({
"class": adata.obs[group_key].values,
"centroid": cell_proj_closest_vertex,
})

cell_color_map = get_color_map_from_labels(adata.obs[group_key].values)

cell_type_info = cell_type_info.groupby(['centroid', 'class']).size().unstack()
cell_type_info = cell_type_info.reindex(centroids_index, fill_value=0)
cells_mapping_percentage = cell_type_info.div(cells_mapping_size, axis=0)
cells_mapping_percentage = np.nan_to_num(cells_mapping_percentage.values)

cells_mapping_size = (cells_mapping_size / len(cell_proj_closest_vertex))
cells_mapping_size = [0.05 if s < 0.05 else s for s in cells_mapping_size]

return cells_mapping_size, cells_mapping_percentage, cell_color_map


def _scale_positions(positions: np.ndarray, variance_scale: int = 1.5) -> np.ndarray:
"""Scale an array representing to the matplotlib coordinates system and scale the variance if needed.

Args:
positions: the array representing the positions of the data to plot.
variance_scale: the value to scale the variance of data.

Returns:
The positions after scaling.
"""
min_value = np.min(positions)
max_value = np.max(positions)
pos = (positions - min_value) / (max_value - min_value)
mean = np.mean(pos, axis=0)
pos = (pos - mean) * variance_scale
return pos


def plot_dim_reduced_direct_graph(
adata: AnnData,
group_key: Optional[str] = "Cell_type",
graph: Optional[Union[csr_matrix, np.ndarray]] = None,
cell_proj_closest_vertex: Optional[np.ndarray] = None,
center_coordinates: Optional[np.ndarray] = None,
display_piechart: bool = True,
variance_scale: int = 1.5,
save_show_or_return: Literal["save", "show", "return"] = "show",
save_kwargs: Dict[str, Any] = {},
) -> Optional[plt.Axes]:
"""Plot the directed graph constructed velocity-guided pseudotime.

Args:
adata: the anndata object.
group_key: the key to locate the groups of each cell in adata.
graph: the directed graph to plot.
cell_proj_closest_vertex: the mapping from each cell to the corresponding node.
center_coordinates: the array representing the positions of the center nodes in the low dimensions. Only need
this when display_piechart is True.
display_piechart: whether to display piechart for each node.
variance_scale: the value to scale the variance of data. This function is employed to space out the pie charts
when they are positioned too closely to each other.
save_show_or_return: whether to save, show or return the plot.
save_kwargs: additional keyword arguments of plot saving.

Returns:
The plot of the directed graph or `None`.
"""

try:
if graph is None:
graph = adata.uns["directed_velocity_tree"]

if cell_proj_closest_vertex is None:
cell_proj_closest_vertex = adata.uns["cell_order"]["pr_graph_cell_proj_closest_vertex"]
except KeyError:
raise KeyError("Cell order data is missing. Please run `tl.order_cells()` first!")

cells_size, cells_percentage, cells_color_map = _calculate_cells_mapping(
adata=adata,
group_key=group_key,
cell_proj_closest_vertex=cell_proj_closest_vertex,
)

cells_colors = np.array([v for v in cells_color_map.values()])

fig, ax = plt.subplots(figsize=(6, 6))

G = nx.from_numpy_array(graph, create_using=nx.DiGraph)

center_coordinates = adata.uns["cell_order"]["Y"].T.copy() if center_coordinates is None else center_coordinates
pos = _scale_positions(center_coordinates, variance_scale=variance_scale)
pos_dict = {}
for i in range(len(pos)):
pos_dict[i] = pos[i]

if display_piechart:

for node in G.nodes:
attributes = cells_percentage[node]

if np.all(attributes == 0):
plt.pie(
[1],
center=pos[node],
colors=[[0, 0, 0, 1]],
radius=cells_size[node],
)
else:
valid_indices = np.where(attributes != 0)[0]
plt.pie(
attributes[valid_indices],
center=pos[node],
colors=cells_colors[valid_indices],
radius=cells_size[node],
)
g = nx.draw_networkx_edges(
G,
pos=pos_dict,
node_size=[s * len(cells_size) * 300 for s in cells_size],
arrows=True,
arrowstyle="->",
arrowsize=20,
ax=ax,
)

else:
dominate_colors = []

for node in G.nodes:
attributes = cells_percentage[node]
if np.all(attributes == 0):
dominate_colors.append([0, 0, 0, 1])
else:
max_idx = np.argmax(attributes)
dominate_colors.append(cells_colors[max_idx])

nx.draw_networkx_nodes(G, pos=pos_dict, node_color=dominate_colors, node_size=[s * len(cells_size) * 300 for s in cells_size], ax=ax)
g = nx.draw_networkx_edges(
G,
pos=pos_dict,
node_size=[s * len(cells_size) * 300 for s in cells_size],
arrows=True,
arrowstyle="->",
arrowsize=20,
ax=ax,
)

cells_color_map["None"] = np.array([0, 0, 0, 1])
plt.legend(handles=[plt.Line2D([0], [0], marker="o", color='w', label=label,
markerfacecolor=color) for label, color in cells_color_map.items()],
loc="best",
fontsize="medium",
)

if save_show_or_return in ["save", "both", "all"]:
s_kwargs = {
"path": None,
"prefix": "plot_dim_reduced_direct_graph",
"dpi": None,
"ext": "pdf",
"transparent": True,
"close": True,
"verbose": True,
}
s_kwargs = update_dict(s_kwargs, save_kwargs)

if save_show_or_return in ["both", "all"]:
s_kwargs["close"] = False

save_fig(**s_kwargs)
if save_show_or_return in ["show", "both", "all"]:
plt.tight_layout()
plt.show()
if save_show_or_return in ["return", "all"]:
return g


def plot_direct_graph(
Expand Down
2 changes: 2 additions & 0 deletions dynamo/plot/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,12 @@ def is_list_of_lists(list_of_lists):

def get_color_map_from_labels(labels: np.ndarray, color_key_cmap: str = "glasbey_white") -> np.ndarray:
"""Generate a color map according to given labels.

Args:
labels: the label representing the groups of data.
color_key_cmap: the cmap used to generate the colors. Recommend 'glasbey_white'/'glasbey_black' for continuous
data, and 'inferno'/'viridis' for discrete data.

Returns:
The mapping of colors corresponding to each unique label.
"""
Expand Down
2 changes: 1 addition & 1 deletion dynamo/tools/DDRTree_py.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def DDRTree(
iterations.
"""

X = np.array(X)
X = np.array(X).T
(D, N) = X.shape

# initialization
Expand Down
3 changes: 3 additions & 0 deletions dynamo/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,10 @@
)

# Pseudotime related
from .construct_velocity_tree import construct_velocity_tree, construct_velocity_tree_py
from .DDRTree_py import DDRTree, cal_ncenter
from .pseudotime import order_cells
from .time_series import directed_pg

# dimension reduction related
from .dimension_reduction import reduceDimension # , run_umap
Expand Down
Loading