"""
Core plasticity simulation module.
This module provides functions for simulating different types of cellular plasticity
in single-cell datasets, including random walk plasticity and cluster-based switching.
"""
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os
import networkx as nx
try:
import walker
HAS_WALKER = True
except ImportError:
HAS_WALKER = False
walker = None
import scanpy as sc
try:
from ete3 import Tree, TreeStyle, NodeStyle, TextFace
except ImportError:
# Handle different ete3 import structures
try:
from ete3 import Tree
from ete3.treeview import TreeStyle, NodeStyle, TextFace
except ImportError:
from ete3 import Tree
TreeStyle = NodeStyle = TextFace = None
from typing import Dict, List, Optional, Tuple, Union
import anndata
[docs]
def random_walk_plasticity(
    full_simulated_ad: anndata.AnnData,
    subset_simulated_ad: anndata.AnnData,
    plastic_cells: Dict[str, List[str]],
    walk_lengths: Dict[str, int],
    latent_space_key: str = 'X_dc'
) -> anndata.AnnData:
    """
    Simulate random walk plasticity in single cells.

    Performs random walks on specified plastic cells to simulate phenotypic transitions.
    Plastic cells from different leiden clusters perform walks of specified lengths,
    and their phenotypes are replaced with their walk targets.

    Parameters
    ----------
    full_simulated_ad : anndata.AnnData
        Complete simulated dataset used for performing random walks.
    subset_simulated_ad : anndata.AnnData
        Subset of the dataset containing cells to be analyzed.
    plastic_cells : Dict[str, List[str]]
        Dictionary mapping leiden cluster identifiers to lists of cell names
        that will undergo plastic transitions.
    walk_lengths : Dict[str, int]
        Dictionary mapping leiden cluster identifiers to walk lengths.
        Must contain entries for all keys in plastic_cells.
    latent_space_key : str, optional
        Key in the obsm attribute of the AnnData object that contains the latent space representation.
        Default is 'X_dc'.

    Returns
    -------
    anndata.AnnData
        Modified dataset with plastic cells replaced by their walk targets
        (each replacement keeps the plastic cell's name). Non-plastic cells
        remain unchanged. Non-plastic cells come first, followed by the plastic
        cells in the iteration order of ``plastic_cells``. When
        ``latent_space_key`` is available in the ``obsm`` of both the input
        subset and the result, obs column ``'change_in_phenotype'`` holds the
        Euclidean distance between each cell's original and new coordinates
        (0 for non-plastic cells).

    Raises
    ------
    AssertionError
        If walk_lengths keys don't match plastic_cells keys, or if the rebuilt
        dataset does not contain exactly the cells of ``subset_simulated_ad``.

    Examples
    --------
    >>> plastic_cells = {'0': ['Cell_1', 'Cell_2'], '1': ['Cell_3']}
    >>> walk_lengths = {'0': 100, '1': 50}
    >>> result = random_walk_plasticity(full_ad, subset_ad, plastic_cells, walk_lengths)

    Notes
    -----
    This function modifies cell phenotypes by replacing plastic cells with cells
    that represent their final positions after random walks. The walk parameters
    (p, q) are set within the perform_random_walk function.
    """
    # Raised explicitly rather than via ``assert`` so the documented check is not
    # stripped under ``python -O``; the AssertionError type is kept for callers.
    if set(walk_lengths.keys()) != set(plastic_cells.keys()):
        raise AssertionError("Walk lengths must be specified for each selected leiden cluster")

    all_plastic_cells = [cell for cells in plastic_cells.values() for cell in cells]
    updated_ad_list = []
    for lc, cells in plastic_cells.items():
        print(f"Leiden cluster {lc}: {len(cells)} plastic cells will perform random walks of length {walk_lengths[lc]}")
        # A fresh copy per cluster is kept: presumably perform_random_walk may
        # modify its input (e.g. add a neighbor graph) — TODO confirm before sharing.
        input_ad = full_simulated_ad.copy()
        # Allow these cells to perform a random walk
        targets, _change_in_phenotype, _walks = perform_random_walk(
            input_ad,
            plastic_cells=cells,
            walk_length=walk_lengths[lc],
            latent_space_key=latent_space_key,
        )
        # Take each walk target's profile but keep the plastic cell's own name.
        updated_phenotypes = input_ad[targets['target']].copy()
        updated_phenotypes.obs_names = targets.index.values
        updated_ad_list.append(updated_phenotypes)

    # Drop all plastic cells from the subset and append their updated versions.
    non_plastic_mask = ~subset_simulated_ad.obs_names.isin(all_plastic_cells)
    non_plastic_ad = subset_simulated_ad[subset_simulated_ad.obs_names[non_plastic_mask].to_list()].copy()
    final_ad = anndata.concat([non_plastic_ad] + updated_ad_list)
    final_ad.obs['leiden'] = final_ad.obs['leiden'].astype(str)

    # Restore the input cell names. Names that already match the subset are
    # kept verbatim, so hyphenated names (e.g. barcodes like 'AAACCTG-1') are
    # not truncated. Any other name loses trailing '-<suffix>' segments (such
    # as a '-0' added during concatenation) until it matches.
    expected_names = set(subset_simulated_ad.obs_names)

    def _restore_name(name: str) -> str:
        while name not in expected_names and '-' in name:
            name = name.rsplit('-', 1)[0]
        return name

    final_ad.obs_names = [_restore_name(name) for name in final_ad.obs_names]

    assert final_ad.shape[0] == subset_simulated_ad.shape[0], "Final dataset should have the same number of cells as the input subset"
    assert set(final_ad.obs_names) == set(subset_simulated_ad.obs_names), "Final dataset should have the same cells as the input subset"

    # Align the original subset to the result's cell order; final_ad itself is
    # left as a real (non-view) AnnData.
    original_ad = subset_simulated_ad[final_ad.obs_names]
    # Compute change_in_phenotype: Euclidean distance between original and new latent space coordinates
    if latent_space_key in original_ad.obsm and latent_space_key in final_ad.obsm:
        original_coords = np.asarray(original_ad.obsm[latent_space_key])
        new_coords = np.asarray(final_ad.obsm[latent_space_key])
        final_ad.obs['change_in_phenotype'] = np.linalg.norm(new_coords - original_coords, axis=1)
    return final_ad
[docs]
def get_distances_of_moves(
    G: nx.Graph,
    sources: List[str],
    targets: List[str]
) -> np.ndarray:
    """
    Compute shortest-path lengths between paired source and target nodes.

    Parameters
    ----------
    G : networkx.Graph
        Graph representing phenotypic connectivity. Edge weights are ignored.
    sources : List[str]
        List of source node names.
    targets : List[str]
        List of target node names, paired element-wise with ``sources``.

    Returns
    -------
    np.ndarray
        Array whose i-th entry is the number of nodes on an unweighted shortest
        path from ``sources[i]`` to ``targets[i]``, counting both endpoints
        (i.e. number of edges + 1; a node paired with itself gives 1).

    Raises
    ------
    ValueError
        If sources and targets have different lengths.
    networkx.NetworkXNoPath
        If a target is unreachable from its source (propagated from networkx).
    """
    if len(sources) != len(targets):
        raise ValueError("Sources and targets must have the same length")
    # len() of the node path, not the edge count: this is the historical
    # convention of this function, preserved for existing callers.
    return np.array([
        len(nx.shortest_path(G, source=source, target=target))
        for source, target in zip(sources, targets)
    ])
[docs]
def construct_phenotypic_graph(
    ad: anndata.AnnData,
    latent_space_key: str,
    n_nbrs: int = 10
) -> nx.Graph:
    """
    Build a k-nearest-neighbour graph of cells from a latent embedding.

    The caller's AnnData is not modified: the neighbour search runs on a copy.

    Parameters
    ----------
    ad : anndata.AnnData
        Annotated data object.
    latent_space_key : str
        Key in `ad.obsm` holding the coordinates used for the neighbour search.
    n_nbrs : int, optional
        Number of nearest neighbors for graph construction, by default 10.

    Returns
    -------
    networkx.Graph
        Graph whose nodes are the cell names of `ad` and whose edges come from
        the `connectivities` matrix computed by `scanpy.pp.neighbors`.
    """
    # copy=True makes scanpy work on (and return) a copy instead of mutating `ad`.
    neighbours_ad = sc.pp.neighbors(
        ad,
        n_neighbors=n_nbrs,
        use_rep=latent_space_key,
        copy=True,
    )
    return graph_from_connectivities(
        neighbours_ad.obsp['connectivities'],
        neighbours_ad.obs_names,
    )
[docs]
def graph_from_connectivities(adj_matrix, cell_names: List[str]) -> nx.Graph:
    """
    Convert a sparse adjacency/connectivity matrix to a NetworkX graph.

    Parameters
    ----------
    adj_matrix : scipy sparse matrix or array
        Square sparse matrix (any scipy sparse format, e.g. ``csr_matrix`` or
        ``csr_array``). Every stored non-zero entry becomes an edge, and its
        value is kept as the edge's ``'weight'`` attribute.
    cell_names : List[str]
        Cell names corresponding to matrix rows/columns, in the same order.

    Returns
    -------
    networkx.Graph
        Graph with nodes labeled by cell names.

    Raises
    ------
    AttributeError
        If adj_matrix is not a scipy sparse matrix/array (type kept for
        backward compatibility with existing callers).
    ValueError
        If the number of cell names does not match the matrix size.
    """
    from scipy.sparse import issparse

    if not issparse(adj_matrix):
        raise AttributeError(f'graph_from_connectivities not implemented for {type(adj_matrix)}')
    names = list(cell_names)
    # zip() would silently leave some nodes unlabeled on a length mismatch.
    if adj_matrix.shape[0] != len(names):
        raise ValueError(
            f'Got {len(names)} cell names for an adjacency matrix with {adj_matrix.shape[0]} rows'
        )
    H = nx.from_scipy_sparse_array(adj_matrix)
    # Relabel into a new graph (copy=True default), avoiding the in-place
    # path's failure mode when old and new labels overlap.
    return nx.relabel_nodes(H, dict(zip(H.nodes(), names)))
[docs]
def visualize_walk(
    ad: anndata.AnnData,
    walk_indices: np.ndarray,
    save_to: Optional[str] = None,
    show_plots: bool = False
) -> None:
    """
    Visualize a random walk path on UMAP coordinates.

    Creates a scatter plot showing the path of a single random walk,
    with origin and target cells highlighted.

    Parameters
    ----------
    ad : anndata.AnnData
        AnnData object containing UMAP coordinates in `obsm['X_umap']`.
    walk_indices : np.ndarray
        1D array of integer cell (row) indices representing the walk path.
        Row/column vectors such as shape (1, n) or (n, 1) are flattened.
    save_to : str, optional
        Directory in which 'random_walk.svg' is saved, by default None.
    show_plots : bool, optional
        Whether to display plots interactively, by default False.

    Raises
    ------
    KeyError
        If `ad.obsm` has no 'X_umap'.
    TypeError
        If walk_indices is not a numpy array.
    ValueError
        If walk_indices cannot be viewed as a 1D array, or is empty.

    Examples
    --------
    >>> _, _, walks = perform_random_walk(ad, ['Cell_1'])
    >>> visualize_walk(ad, walks[0])

    Notes
    -----
    - Requires UMAP coordinates in ad.obsm['X_umap']
    - Origin cell is marked with a black star
    - Target cell (last step, only for walks longer than one cell) is marked with a red star
    - Intermediate cells are colored with a gradient
    """
    if 'X_umap' not in ad.obsm.keys():
        raise KeyError('UMAP coordinates not found in AnnData object')
    if not isinstance(walk_indices, np.ndarray):
        raise TypeError('walk_indices must be a numpy array')
    if walk_indices.ndim != 1:
        # Accept multi-dimensional arrays with at most one non-singleton axis
        # (e.g. (1, n) or (n, 1)); reject 0-d and genuinely 2D+ arrays.
        if walk_indices.ndim >= 2 and sum(dim > 1 for dim in walk_indices.shape) <= 1:
            walk_indices = walk_indices.flatten()
        else:
            raise ValueError('walk_indices must be a 1D array')
    if walk_indices.size == 0:
        raise ValueError('walk_indices must contain at least one cell index')

    umap_coords = ad.obsm['X_umap']
    n_steps = len(walk_indices)
    colors = sns.color_palette("Wistia", n_steps)

    fig = plt.figure(figsize=(10, 8))
    try:
        plt.scatter(umap_coords[:, 0], umap_coords[:, 1], color='grey', s=1, alpha=0.5)
        # Intermediate steps (everything except origin and target) in one call,
        # each coloured by its position along the palette gradient.
        if n_steps > 2:
            middle = walk_indices[1:-1]
            plt.scatter(umap_coords[middle, 0], umap_coords[middle, 1],
                        s=5, color=list(colors[1:-1]))
        # Target: the final step, only when the walk actually moved.
        if n_steps > 1:
            target_idx = walk_indices[-1]
            plt.scatter(umap_coords[target_idx, 0], umap_coords[target_idx, 1],
                        s=100, color='red', label='Target', marker='*')
        # Plot origin
        origin_idx = walk_indices[0]
        plt.scatter(umap_coords[origin_idx, 0], umap_coords[origin_idx, 1],
                    s=100, color='black', label='Origin', marker='*')
        plt.title('UMAP plot with Highlighted Walk')
        plt.legend()
        plt.xlabel('UMAP 1')
        plt.ylabel('UMAP 2')
        if save_to is not None:
            plt.savefig(os.path.join(save_to, 'random_walk.svg'), bbox_inches='tight', dpi=300)
            print(f'Saved plot to {os.path.join(save_to, "random_walk.svg")}')
        if show_plots:
            plt.show()
    finally:
        # Always release the figure, even if saving fails.
        plt.close(fig)
[docs]
def cluster_switch_plasticity(
    full_simulated_ad: anndata.AnnData,
    subset_simulated_ad: anndata.AnnData,
    plastic_cells: Dict[str, Dict[str, List[str]]],
    column: str = 'leiden',
    latent_space_key: str = 'X_dc'
) -> anndata.AnnData:
    """
    Simulate plasticity through direct cluster switches.

    Replaces plastic cells with randomly selected cells from their target clusters,
    simulating direct phenotypic transitions without intermediate states.

    Parameters
    ----------
    full_simulated_ad : anndata.AnnData
        Complete simulated dataset used for selecting replacement cells.
    subset_simulated_ad : anndata.AnnData
        Subset of the dataset containing cells to be analyzed.
    plastic_cells : Dict[str, Dict[str, List[str]]]
        Dictionary mapping cluster IDs (as strings) to a dictionary with keys 'destination' (target cluster ID as string)
        and 'cells' (list of cell names that will undergo plastic transitions).
    column : str, optional
        Column in obs containing cluster annotations, by default 'leiden'.
    latent_space_key : str, optional
        Key in `obsm` containing latent space coordinates for distance calculation, by default 'X_dc'.

    Returns
    -------
    anndata.AnnData
        Modified dataset with plastic cells replaced by cells from target clusters
        (each replacement keeps the plastic cell's name). Non-plastic cells come
        first, followed by the switched cells. Includes 'change_in_phenotype'
        column with Euclidean distance changes in latent space when
        ``latent_space_key`` is available in both the subset and the result.

    Raises
    ------
    ValueError
        If plastic cells are not found in subset_simulated_ad.

    Examples
    --------
    >>> plastic_cells = {'5': {'destination': '4', 'cells': ['Cell_1', 'Cell_2']}}
    >>> result = cluster_switch_plasticity(full_ad, subset_ad, plastic_cells)

    Notes
    -----
    This function directly replaces plastic cells with cells from random target
    clusters, simulating abrupt phenotypic switches without gradual transitions.
    All cluster IDs are treated as strings for consistency. Both the `column`
    values and the destination IDs are converted with ``str`` before matching.
    If a target cluster has fewer cells than requested, only the first
    ``len(target cluster)`` plastic cells of that entry switch; the remaining
    ones keep their original phenotype. The 'change_in_phenotype' column
    contains the Euclidean distance between original and new positions in
    latent space (0 for non-plastic cells).
    """
    # Validate that all plastic cells exist in the subset
    all_plastic_cells = [cell for spec in plastic_cells.values() for cell in spec['cells']]
    missing = [c for c in all_plastic_cells if c not in subset_simulated_ad.obs_names]
    if missing:
        raise ValueError(f"Some plastic cells not found in subset_simulated_ad: {missing}")

    # String labels so integer/categorical annotations match string destination IDs.
    cluster_labels = full_simulated_ad.obs[column].astype(str).to_numpy()

    updated_ad_list = []
    switched_cells = []  # plastic cells that actually received a new phenotype
    for spec in plastic_cells.values():
        target_cluster = spec['destination']
        source_cells = list(spec['cells'])
        target_cells = full_simulated_ad.obs_names[cluster_labels == str(target_cluster)].to_list()
        n_plastic = len(source_cells)
        if len(target_cells) < n_plastic:
            print(f"Warning: Not enough cells in target cluster {target_cluster}. Limiting number of plastic cells to {len(target_cells)}.")
            n_plastic = len(target_cells)
            # Previously this path crashed on mismatched list lengths; now the
            # surplus plastic cells simply keep their original phenotype.
            source_cells = source_cells[:n_plastic]
        # Randomly select the correct number of target cells (same RNG call as before).
        selected_target_cells = np.random.choice(target_cells, n_plastic, replace=False).tolist()
        if not source_cells:
            continue
        # Take the target cells' profiles but keep the plastic cells' own names.
        updated_phenotypes = full_simulated_ad[selected_target_cells].copy()
        updated_phenotypes.obs_names = source_cells
        updated_ad_list.append(updated_phenotypes)
        switched_cells.extend(source_cells)

    # Drop the switched cells from the subset and append their updated versions.
    non_plastic_mask = ~subset_simulated_ad.obs_names.isin(switched_cells)
    non_plastic_ad = subset_simulated_ad[subset_simulated_ad.obs_names[non_plastic_mask].to_list()].copy()
    final_ad = anndata.concat([non_plastic_ad] + updated_ad_list)
    # Cluster annotations are exposed as strings ('leiden' as before, plus `column`).
    for label_column in dict.fromkeys([column, 'leiden']):
        if label_column in final_ad.obs:
            final_ad.obs[label_column] = final_ad.obs[label_column].astype(str)

    # Restore the input cell names. Names that already match the subset are
    # kept verbatim, so hyphenated names (e.g. barcodes like 'AAACCTG-1') are
    # not truncated. Any other name loses trailing '-<suffix>' segments (such
    # as a '-0' added during concatenation) until it matches.
    expected_names = set(subset_simulated_ad.obs_names)

    def _restore_name(name: str) -> str:
        while name not in expected_names and '-' in name:
            name = name.rsplit('-', 1)[0]
        return name

    final_ad.obs_names = [_restore_name(name) for name in final_ad.obs_names]

    assert final_ad.shape[0] == subset_simulated_ad.shape[0], "Final dataset should have the same number of cells as the input subset"
    assert set(final_ad.obs_names) == set(subset_simulated_ad.obs_names), "Final dataset should have the same cells as the input subset"

    # Align the original subset to the result's cell order; final_ad itself is
    # left as a real (non-view) AnnData.
    original_ad = subset_simulated_ad[final_ad.obs_names]
    # Compute change_in_phenotype: Euclidean distance between original and new latent space coordinates
    if latent_space_key in original_ad.obsm and latent_space_key in final_ad.obsm:
        original_coords = np.asarray(original_ad.obsm[latent_space_key])
        new_coords = np.asarray(final_ad.obsm[latent_space_key])
        final_ad.obs['change_in_phenotype'] = np.linalg.norm(new_coords - original_coords, axis=1)
    return final_ad
[docs]
def plot_leiden_transitions(full_simulated_ad: anndata.AnnData,
                            destination_clusters: dict,
                            show_plots: bool = True,
                            save_to: Optional[str] = None
                            ) -> None:
    """
    Plot a UMAP with arrows indicating transitions between specified leiden clusters.

    Each arrow goes from the UMAP centroid (mean of the first two UMAP
    components) of a source cluster to the centroid of its destination cluster.

    Parameters
    ----------
    full_simulated_ad : AnnData
        The AnnData object containing the single-cell data with UMAP coordinates and leiden cluster annotations.
    destination_clusters : dict
        A dictionary where keys are source leiden cluster IDs (str) and values are dictionaries with keys:
        - 'destination': target leiden cluster ID (str)
        - 'proportion': proportion of cells to transition (float between 0 and 1)
        Only 'destination' is read here. IDs are compared as strings, so
        integer cluster IDs are accepted as well.
    show_plots : bool, optional
        Whether to display the plot immediately. Default is True.
    save_to : str, optional
        If provided, the file path to save the plot image. Default is None (do not save).
        The figure is saved before it is shown, so the file is never blank.

    Raises
    ------
    ValueError
        If UMAP coordinates or leiden clusters are missing.
    KeyError
        If a source or destination cluster does not occur in ``obs['leiden']``.
    """
    # Ensure UMAP has been computed
    if 'X_umap' not in full_simulated_ad.obsm:
        raise ValueError("UMAP coordinates not found in 'obsm'. Please compute UMAP before plotting.")
    # Ensure leiden clustering has been performed
    if 'leiden' not in full_simulated_ad.obs:
        raise ValueError("Leiden clustering not found in 'obs'. Please perform clustering before plotting.")

    # Base UMAP without immediate display; arrows are drawn on its axes.
    sc.pl.umap(
        full_simulated_ad,
        color='leiden',
        title='Leiden Clusters',
        size=40,
        frameon=False,
        edges=True,
        edges_color='black',
        show=False
    )
    ax = plt.gca()
    fig = ax.figure
    try:
        # Only the first two components are plotted, so centroids use those.
        umap_coords = np.asarray(full_simulated_ad.obsm['X_umap'])[:, :2]
        leiden = full_simulated_ad.obs['leiden'].astype(str).to_numpy()
        centroids = {
            cluster: umap_coords[leiden == cluster].mean(axis=0)
            for cluster in np.unique(leiden)
        }
        # Source → Target
        for src, spec in destination_clusters.items():
            src_key, tgt_key = str(src), str(spec['destination'])
            unknown = [key for key in (src_key, tgt_key) if key not in centroids]
            if unknown:
                raise KeyError(f"Leiden cluster(s) {unknown} not found in full_simulated_ad.obs['leiden']")
            ax.annotate(
                '',
                xy=centroids[tgt_key], xycoords='data',
                xytext=centroids[src_key], textcoords='data',
                arrowprops=dict(
                    lw=2,
                    facecolor="black",   # fill color
                    edgecolor="white",   # outline color
                    mutation_scale=25,   # bigger arrowhead
                    shrinkA=5, shrinkB=5  # space around arrow ends
                )
            )
        plt.tight_layout()
        # Save before show: after plt.show() returns the figure may be gone,
        # so saving afterwards could write a blank image.
        if save_to:
            fig.savefig(save_to)
        if show_plots:
            plt.show()
    finally:
        plt.close(fig)
    return
[docs]
def plot_change_in_phenotype(
    ad: anndata.AnnData,
    plastic_ad: anndata.AnnData,
    all_plastic_cells: List[str],
    latent_space_key: str = 'X_dc',
    show_plots: bool = False,
    save_to: Optional[str] = None
) -> pd.Series:
    """
    Visualize phenotypic changes in plastic cells before and after plasticity simulation.

    Creates a three-panel plot showing:
    1. Original data with plastic cells highlighted
    2. Data after plasticity with plastic cells highlighted
    3. Distribution of phenotypic change distances

    Parameters
    ----------
    ad : anndata.AnnData
        Original AnnData object before plasticity simulation.
    plastic_ad : anndata.AnnData
        AnnData object after plasticity simulation.
    all_plastic_cells : List[str]
        List of cell names that underwent plastic transitions.
    latent_space_key : str, optional
        Key in `obsm` containing latent space coordinates for distance calculation, by default 'X_dc'.
    show_plots : bool, optional
        Whether to display plots interactively, by default False.
    save_to : str, optional
        File path to save the plot to, by default None (do not save).
        The figure is saved before it is shown, so the file is never blank.

    Returns
    -------
    pd.Series
        Series named 'phenotypic_change_distance', indexed by plastic cell name,
        holding the Euclidean distance between each cell's latent coordinates
        in `ad` and in `plastic_ad`.

    Raises
    ------
    KeyError
        If required keys are not found in the AnnData objects.
    ValueError
        If plastic cells are not found in both datasets.

    Examples
    --------
    >>> change_distances = plot_change_in_phenotype(
    ...     original_ad, plastic_ad, ['Cell_1', 'Cell_2'], show_plots=True
    ... )
    >>> print(f"Mean change: {change_distances.mean():.3f}")

    Notes
    -----
    - Requires UMAP coordinates in both AnnData objects
    - Requires latent space coordinates for distance calculation
    - Plastic cells must be present in both original and plastic datasets
    """
    # Validate inputs
    if 'X_umap' not in ad.obsm:
        raise KeyError("UMAP coordinates not found in original data. Run sc.tl.umap() first.")
    if 'X_umap' not in plastic_ad.obsm:
        raise KeyError("UMAP coordinates not found in plastic data. Run sc.tl.umap() first.")
    if latent_space_key not in ad.obsm:
        raise KeyError(f"Latent space key '{latent_space_key}' not found in original data.")
    if latent_space_key not in plastic_ad.obsm:
        raise KeyError(f"Latent space key '{latent_space_key}' not found in plastic data.")
    if 'leiden' not in ad.obs:
        raise KeyError("Leiden clustering not found in original data. Run sc.tl.leiden() first.")
    if 'leiden' not in plastic_ad.obs:
        raise KeyError("Leiden clustering not found in plastic data. Run sc.tl.leiden() first.")

    # Check if plastic cells exist in both datasets
    plastic_cells_set = set(all_plastic_cells)
    missing_original = plastic_cells_set - set(ad.obs_names)
    missing_plastic = plastic_cells_set - set(plastic_ad.obs_names)
    if missing_original:
        raise ValueError(f"Plastic cells not found in original dataset: {list(missing_original)[:5]}...")
    if missing_plastic:
        raise ValueError(f"Plastic cells not found in plastic dataset: {list(missing_plastic)[:5]}...")

    # Latent-space change distances, aligned by cell name in both datasets.
    prev_coords = pd.DataFrame(ad.obsm[latent_space_key], index=ad.obs_names).loc[all_plastic_cells]
    new_coords = pd.DataFrame(plastic_ad.obsm[latent_space_key], index=plastic_ad.obs_names).loc[all_plastic_cells]
    change_distances = np.linalg.norm(new_coords.values - prev_coords.values, axis=1)
    change_distances_series = pd.Series(change_distances, index=all_plastic_cells, name='phenotypic_change_distance')

    # Set up custom subplot widths
    fig, axs = plt.subplots(
        1, 3, figsize=(20, 6),
        gridspec_kw={"width_ratios": [2, 2, 1]}  # make last panel narrower
    )
    try:
        # Panels 1 and 2: UMAP coloured by leiden, plastic cells outlined in black.
        panels = (
            (ad, "Original Data with Plastic Cells Highlighted"),
            (plastic_ad, "After Plasticity (Plastic Cells Outlined)"),
        )
        for panel_ax, (panel_ad, panel_title) in zip(axs[:2], panels):
            sc.pl.umap(
                panel_ad,
                color="leiden",
                title=panel_title,
                size=40,
                frameon=False,
                edges=True,
                edges_color="black",
                ax=panel_ax,
                show=False
            )
            umap_df = pd.DataFrame(panel_ad.obsm['X_umap'], index=panel_ad.obs_names)
            panel_ax.scatter(
                umap_df.loc[all_plastic_cells, 0],
                umap_df.loc[all_plastic_cells, 1],
                s=40, facecolors='none', edgecolors='black', linewidth=1,
                label=f'Plastic cells (n={len(all_plastic_cells)})'
            )

        # Panel 3: distribution of phenotypic change distances
        sns.histplot(change_distances, bins=30, kde=True, ax=axs[2], alpha=0.7)
        axs[2].set_title("Phenotypic Change Distribution")
        axs[2].set_xlabel(f"Change Distance ({latent_space_key})")
        axs[2].set_ylabel("Count")
        # Summary statistics; skipped for an empty selection (mean/median are NaN).
        if change_distances.size:
            mean_change = change_distances.mean()
            median_change = np.median(change_distances)
            axs[2].axvline(mean_change, color='red', linestyle='--', label=f'Mean: {mean_change:.3f}')
            axs[2].axvline(median_change, color='orange', linestyle=':', label=f'Median: {median_change:.3f}')
            axs[2].legend()

        plt.tight_layout()
        # Save before show: after plt.show() returns the figure may be gone.
        if save_to:
            fig.savefig(save_to)
        if show_plots:
            plt.show()
    finally:
        plt.close(fig)
    return change_distances_series