Source code for plastro.phenotype_simulation

"""
Phenotype simulation module for generating synthetic single-cell data.

This module provides functions for creating synthetic single-cell datasets with
branching differentiation trajectories, simulating realistic cellular development
patterns and phenotypic transitions for testing plasticity algorithms.
"""

import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.stats import multivariate_normal
from anndata import AnnData
import scanpy as sc
from typing import Tuple, List, Optional

# Note: HIDE_PLOTS is now controlled per function via show_plots parameter


def create_random_binary_tree(n_leaves: int, sample_res: int) -> Tuple:
    """
    Create a binary tree with a random distribution of leaves.

    Builds a random binary tree that serves as the sampling plan for a
    simulated differentiation hierarchy. Every node (internal or leaf)
    carries a cell count, and the nesting encodes parent/child relationships.

    Parameters
    ----------
    n_leaves : int
        Total number of terminal branches (leaf nodes) in the tree.
        Must be at least 1.
    sample_res : int
        Sample resolution multiplier. Each node receives
        ``random.randint(1, 10) * sample_res`` cells.

    Returns
    -------
    Tuple
        A ``(sample_count, children)`` pair:

        - Leaf nodes: ``(sample_count, [])``
        - Internal nodes: ``(sample_count, [left_child, right_child])``

    Raises
    ------
    ValueError
        If ``n_leaves`` is smaller than 1.

    Examples
    --------
    >>> tree = create_random_binary_tree(n_leaves=4, sample_res=50)
    >>> # 4 terminal branches; every node holds between 50 and 500 cells

    Notes
    -----
    Randomness comes from the standard-library ``random`` module, so call
    ``random.seed(...)`` beforehand for reproducible trees. A tree with
    ``n_leaves`` leaves always has ``2 * n_leaves - 1`` nodes in total.
    """
    # Without this guard an invalid count only surfaces as an opaque
    # "empty range" error raised from inside random.randint.
    if n_leaves < 1:
        raise ValueError(f"n_leaves must be at least 1, got {n_leaves}")

    if n_leaves == 1:
        # Terminal node: a leaf with a random sample count.
        return (random.randint(1, 10) * sample_res, [])

    # Split the leaves between the two subtrees; each side gets at least one.
    left_leaves = random.randint(1, n_leaves - 1)
    right_leaves = n_leaves - left_leaves

    # Children are built before the node's own count is drawn; keeping this
    # order keeps seeded output stable.
    left_branch = create_random_binary_tree(left_leaves, sample_res)
    right_branch = create_random_binary_tree(right_leaves, sample_res)

    node_samples = random.randint(1, 10) * sample_res
    return (node_samples, [left_branch, right_branch])
def sample_branch(
    base: np.ndarray,
    velocity: np.ndarray,
    sample_structure: Tuple,
    curvature: float = 0.2,
    var_decay: float = 1.5,
    dens_decay: float = 0.9,
    n_dim: int = 15,
    branch_name: str = 'b'
) -> Tuple[List[np.ndarray], List, List[int], List[np.ndarray]]:
    """
    Sample cells along a differentiation branch and, recursively, its children.

    Each branch is one multivariate normal distribution centred on the
    midpoint of a straight segment that starts at ``base`` and follows a
    randomly perturbed velocity. Child branches start where the parent
    segment ends and inherit the parent's perturbed velocity.

    Parameters
    ----------
    base : np.ndarray
        Starting position in expression space, shape ``(n_dim,)``.
    velocity : np.ndarray
        Incoming direction vector, shape ``(n_dim,)``.
    sample_structure : Tuple
        ``(n_cells, children)`` node as produced by
        ``create_random_binary_tree``.
    curvature : float, optional
        Mixing weight between the incoming velocity (``1 - curvature``) and
        an isotropic Gaussian perturbation (``curvature``), by default 0.2.
    var_decay : float, optional
        Controls how quickly the standard deviation shrinks across the
        successive orthogonal axes of the covariance: axis ``k`` is scaled
        by ``exp(-k * var_decay / 2)``. By default 1.5.
    dens_decay : float, optional
        Multiplicative factor applied to the velocity at every level, by
        default 0.9. Values below 1 shorten each successive segment.
    n_dim : int, optional
        Number of dimensions (genes); must match the length of ``base`` and
        ``velocity``. By default 15.
    branch_name : str, optional
        Label for this branch, by default 'b'. Child ``i`` is labelled
        ``f'{branch_name}-{i}'``.

    Returns
    -------
    Tuple[List[np.ndarray], List, List[int], List[np.ndarray]]
        Four parallel lists, this branch first and then its descendants in
        depth-first order:

        - samples: ``(n_cells, n_dim)`` array per branch
        - distributions: frozen multivariate normal per branch
        - n_draws: cell count per branch
        - names: array with the branch label repeated ``n_cells`` times

    Examples
    --------
    >>> base = np.zeros(10)
    >>> velocity = np.ones(10)
    >>> structure = (100, [])  # Simple leaf with 100 cells
    >>> samples, dists, counts, names = sample_branch(
    ...     base, velocity, structure, n_dim=10
    ... )

    Notes
    -----
    Random numbers are drawn from NumPy's global random state
    (``np.random.normal`` and the default state used by ``dist.rvs``).
    """
    n_draws, sub_struct = sample_structure

    # Perturb the incoming velocity; the perturbation is scaled to the
    # velocity's own magnitude so curvature is relative, not absolute.
    scale = np.sqrt(np.sum(velocity**2))
    delta = np.random.normal(0, scale, n_dim)
    new_velocity = (velocity * (1 - curvature) + delta * curvature) * dens_decay

    # The distribution is centred on the midpoint of this segment.
    mu = base + (new_velocity / 2)

    # Complete QR of the velocity column gives an orthonormal basis whose
    # first axis is parallel to the velocity; |r[0]| is the velocity norm.
    q, r = np.linalg.qr(new_velocity[:, None], mode='complete')
    # Per-axis standard deviations: largest along the direction of travel,
    # decaying over the remaining axes, with a floor that keeps the
    # covariance well-conditioned.
    mr = 5e-1 * np.exp(-np.arange(n_dim) * var_decay / 2) * np.abs(r[0])
    mr = np.clip(mr, 1e-4 * np.abs(r[0]), None)
    bases = q * mr[None, :]
    cov = bases.dot(bases.T)

    dist = multivariate_normal(mu, cov)
    # rvs() squeezes singleton axes (a single draw comes back 1-D); restore
    # the (n_cells, n_dim) shape so branches can always be concatenated.
    samples = np.reshape(dist.rvs(n_draws), (n_draws, n_dim))

    samp_list = [samples]
    dist_list = [dist]
    draws_list = [n_draws]
    name_list = [np.repeat(branch_name, n_draws)]

    # Children start at the end of this segment.
    new_base = base + new_velocity

    for i, ss in enumerate(sub_struct):
        child_samples, child_dist, child_n_draws, child_names = sample_branch(
            new_base,
            new_velocity,
            ss,
            curvature=curvature,
            var_decay=var_decay,
            # Forward the caller's value; otherwise children silently fall
            # back to the default.
            dens_decay=dens_decay,
            n_dim=n_dim,
            branch_name=f'{branch_name}-{i}',
        )
        samp_list += child_samples
        dist_list += child_dist
        draws_list += child_n_draws
        name_list += child_names

    return samp_list, dist_list, draws_list, name_list
def _ordered_labels(labels: pd.Series) -> List:
    """Return distinct labels in category order, else order of first appearance."""
    if isinstance(labels.dtype, pd.CategoricalDtype):
        return list(labels.cat.categories)
    return list(labels.unique())


def generate_ad(sample_structure: Tuple, n_dim: int, show_plots: bool = False) -> AnnData:
    """
    Generate an AnnData object with simulated single-cell data.

    Samples cells along every branch of ``sample_structure`` (starting at
    the origin with an all-ones velocity), evaluates the ground-truth
    mixture density at every cell, and annotates the result with a UMAP
    embedding, Leiden clusters and per-label colours.

    Parameters
    ----------
    sample_structure : Tuple
        Binary tree structure from ``create_random_binary_tree``.
    n_dim : int
        Number of dimensions (genes) in the expression space.
    show_plots : bool, optional
        Accepted for interface compatibility; this function does not
        currently draw any plot. By default False.

    Returns
    -------
    AnnData
        Annotated dataset containing:

        - ``X``: expression matrix (n_cells x n_dim)
        - ``obs``: ``ground_truth``, ``log_ground_truth``, ``branch``,
          ``branch_name``, ``branch_colors``, ``leiden``, ``leiden_colors``
        - ``obsm``: ``X_dc`` (a copy of the expression matrix) and the UMAP
          coordinates written by ``sc.tl.umap``
        - ``uns``: ``branch_colors`` and ``leiden_colors`` hex-colour lists

    Examples
    --------
    >>> structure = create_random_binary_tree(n_leaves=6, sample_res=100)
    >>> adata = generate_ad(structure, n_dim=20)
    >>> print(f"Generated {adata.n_obs} cells with {adata.n_vars} genes")
    >>>
    >>> # Visualize the simulated data
    >>> import scanpy as sc
    >>> sc.pl.umap(adata, color='branch')

    Notes
    -----
    The ground-truth density is the mixture of the per-branch normal
    distributions, weighted by each branch's share of the cells.

    If Leiden clustering raises ``ImportError``, a warning is printed and
    ``obs['leiden']`` is filled with one dummy cluster per branch.
    """
    # Generate samples along all branches.
    sample_list, distributions, draws_list, name_list = sample_branch(
        np.zeros(n_dim), np.ones(n_dim), sample_structure, n_dim=n_dim
    )

    samples = np.concatenate(sample_list, axis=0)
    n_draws = np.array(draws_list)
    n_cells = np.sum(n_draws)
    weights = n_draws / n_cells

    # Ground-truth density: the branch mixture evaluated at every cell.
    pdf = np.zeros(n_cells)
    for w, dist in zip(weights, distributions):
        pdf += w * dist.pdf(samples)

    sim_ad = AnnData(
        samples,
        obs=pd.DataFrame({
            'ground_truth': pdf,
            'log_ground_truth': np.log10(pdf),
            'branch': np.concatenate(name_list, axis=0),
        }, index=np.arange(n_cells).astype(str))
    )

    sim_ad.obs['branch_name'] = sim_ad.obs['branch'].astype('category')
    sim_ad.obs_names = 'Cell_' + sim_ad.obs_names.astype(str)
    # 'X_dc' is simply the expression matrix itself, not a computed embedding.
    sim_ad.obsm['X_dc'] = sim_ad.to_df().values

    # UMAP on a neighbour graph built directly from X.
    sc.pp.neighbors(sim_ad, n_neighbors=30, use_rep='X')
    sc.tl.umap(sim_ad)

    # Branch colours. plt.get_cmap(name, lut) is used because the older
    # plt.cm.get_cmap accessor is no longer available in recent Matplotlib.
    branch_order = _ordered_labels(sim_ad.obs['branch'])
    branch_cmap = plt.get_cmap('tab20', len(branch_order))
    branch_palette = {
        branch: rgba_to_hex(branch_cmap(i))
        for i, branch in enumerate(branch_order)
    }
    sim_ad.uns['branch_colors'] = [branch_palette[branch] for branch in branch_order]
    sim_ad.obs['branch_colors'] = [branch_palette[branch] for branch in sim_ad.obs['branch']]

    # Leiden clustering (optional, requires igraph).
    try:
        # NOTE(review): this repeats the neighbour search without use_rep,
        # so scanpy picks the representation itself - confirm it is intended.
        sc.pp.neighbors(sim_ad, n_neighbors=30)
        sc.tl.leiden(sim_ad)
    except ImportError:
        print("Warning: igraph not available, skipping Leiden clustering")
        print("Install with: pip install igraph")
        # Fall back to one dummy cluster per branch.
        fallback_branches = sim_ad.obs['branch'].unique()
        branch_to_leiden = {branch: str(i) for i, branch in enumerate(fallback_branches)}
        sim_ad.obs['leiden'] = sim_ad.obs['branch'].map(branch_to_leiden)

    # Cluster colours. scanpy matches uns['<key>_colors'] to the categories
    # of obs['<key>'] by position, so the list must follow category order
    # for it to agree with the per-cell obs['leiden_colors'] column.
    cluster_order = _ordered_labels(sim_ad.obs['leiden'])
    cluster_cmap = plt.get_cmap('tab20', len(cluster_order))
    cluster_palette = {
        cluster: rgba_to_hex(cluster_cmap(i))
        for i, cluster in enumerate(cluster_order)
    }
    sim_ad.uns['leiden_colors'] = [cluster_palette[cluster] for cluster in cluster_order]
    sim_ad.obs['leiden_colors'] = [cluster_palette[cluster] for cluster in sim_ad.obs['leiden']]

    return sim_ad
def _terminal_branch_names(branch_names) -> List[str]:
    """Return, in input order, the names that have no descendant in ``branch_names``."""
    names = list(branch_names)
    # A descendant is "<name>-<suffix>". Requiring the separator stops
    # 'b-1' from being mistaken for the parent of 'b-10'.
    return [
        name for name in names
        if not any(other.startswith(f'{name}-') for other in names)
    ]


def subset_to_terminal_branches(ad: AnnData, show_plots: bool = False) -> AnnData:
    """
    Extract only terminal branches from simulated differentiation data.

    Keeps the cells whose ``obs['branch']`` label belongs to a terminal
    (leaf) branch and prints a per-branch cell count.

    Parameters
    ----------
    ad : AnnData
        Annotated data object with branch annotations in ``obs['branch']``.
        ``obsm['X_umap']`` is only required when ``show_plots`` is True.
    show_plots : bool, optional
        Whether to display a UMAP scatter with terminal cells in red over
        all cells in grey, by default False.

    Returns
    -------
    AnnData
        A copy containing only cells from terminal branches; ``ad`` itself
        is not modified.

    Examples
    --------
    >>> # Generate full differentiation tree
    >>> structure = create_random_binary_tree(n_leaves=4, sample_res=100)
    >>> full_data = generate_ad(structure, n_dim=15)
    >>>
    >>> # Extract only terminal cell types
    >>> terminal_data = subset_to_terminal_branches(full_data)
    >>> print(f"Reduced from {full_data.n_obs} to {terminal_data.n_obs} cells")
    >>>
    >>> # Show terminal branch visualization
    >>> terminal_data = subset_to_terminal_branches(full_data, show_plots=True)

    Notes
    -----
    Branch labels are expected to follow the ``parent-index`` naming scheme
    (``'b'``, ``'b-0'``, ``'b-0-1'``, ...). A branch is terminal when no
    other label starts with its name followed by ``'-'``. For example, in
    ``['b', 'b-0', 'b-0-1']`` only ``'b-0-1'`` is terminal.
    """
    branch_labels = ad.obs['branch']
    terminal_branches = _terminal_branch_names(branch_labels.unique())
    terminal_mask = branch_labels.isin(terminal_branches)

    # Only build the figure when it will actually be shown.
    if show_plots:
        umap = ad.obsm['X_umap']
        is_terminal = terminal_mask.to_numpy()
        plt.figure(figsize=(5, 5))
        plt.scatter(umap[:, 0], umap[:, 1],
                    c='grey', s=0.5, alpha=0.5, label='Intermediate')
        plt.scatter(umap[is_terminal, 0], umap[is_terminal, 1],
                    c='red', s=0.5, label='Terminal')
        plt.axis('off')
        plt.legend()
        plt.title('Terminal Branches (Red)')
        plt.show()

    # Copy only the selected cells; copying the full input first is unnecessary.
    sim_ad = ad[terminal_mask, :].copy()

    print(f"Selected {len(terminal_branches)} terminal branches:")
    for branch in terminal_branches:
        n_cells = int((sim_ad.obs['branch'] == branch).sum())
        print(f"  {branch}: {n_cells} cells")

    return sim_ad
def rgba_to_hex(rgba: Tuple) -> str:
    """
    Convert an RGBA color to hexadecimal format.

    Parameters
    ----------
    rgba : Tuple
        Four color components, given either as floats in the 0-1 range or
        as integers in the 0-255 range. If any component is a float, every
        float component is scaled by 255 and rounded to the nearest
        integer; integer components are used as they are.

    Returns
    -------
    str
        Hexadecimal color string in the format ``'#RRGGBBAA'``.

    Examples
    --------
    >>> rgba_to_hex((1.0, 0.5, 0.0, 1.0))  # Orange, full opacity
    '#ff8000ff'
    >>> rgba_to_hex((255, 128, 0, 255))  # Same color, integer format
    '#ff8000ff'
    """
    if any(isinstance(component, float) for component in rgba):
        # Round rather than truncate: truncation maps 0.5 to 0x7f instead of
        # 0x80. int() keeps the value a plain integer for the hex format
        # spec even if round() hands back a float-like type.
        rgba = tuple(
            int(round(component * 255)) if isinstance(component, float) else component
            for component in rgba
        )

    return "#{:02x}{:02x}{:02x}{:02x}".format(*rgba)
def simulate_realistic_dataset(
    n_cell_types: int = 6,
    cells_per_type: int = 100,
    n_genes: int = 20,
    noise_level: float = 0.2,
    seed: Optional[int] = None,
    show_plots: bool = False
) -> AnnData:
    """
    Generate a synthetic single-cell dataset for plasticity testing.

    Convenience wrapper that builds a random binary tree with
    ``n_cell_types`` leaves, samples cells along it via ``generate_ad`` and
    prints a short summary.

    Parameters
    ----------
    n_cell_types : int, optional
        Number of terminal cell types (tree leaves), by default 6.
        Must be at least 1.
    cells_per_type : int, optional
        Approximate number of cells per branch, by default 100. Must be at
        least 1. The tree's sample resolution is ``cells_per_type // 5``
        (minimum 1), so each branch holds between roughly 0.2x and 2x this
        value.
    n_genes : int, optional
        Number of genes (dimensions) in expression space, by default 20.
    noise_level : float, optional
        Accepted for interface compatibility, by default 0.2.
        NOTE(review): this value is not currently forwarded to the sampler,
        so it has no effect on the generated data.
    seed : int, optional
        If given, seeds both the standard-library ``random`` module and
        NumPy's global random state. By default None.
    show_plots : bool, optional
        Forwarded to ``generate_ad``, by default False.

    Returns
    -------
    AnnData
        Annotated dataset as returned by ``generate_ad``.

    Raises
    ------
    ValueError
        If ``n_cell_types`` or ``cells_per_type`` is smaller than 1.

    Examples
    --------
    >>> # Generate a standard test dataset
    >>> adata = simulate_realistic_dataset(
    ...     n_cell_types=8,
    ...     cells_per_type=150,
    ...     n_genes=25,
    ...     seed=42,
    ... )
    >>>
    >>> # Visualize the dataset
    >>> import scanpy as sc
    >>> sc.pl.umap(adata, color=['branch', 'leiden'])

    Notes
    -----
    Seeding mutates global random state, which also affects other code that
    draws from ``random`` or ``np.random`` afterwards.
    """
    # Validate before touching any global state.
    if n_cell_types < 1:
        raise ValueError(f"n_cell_types must be at least 1, got {n_cell_types}")
    if cells_per_type < 1:
        raise ValueError(f"cells_per_type must be at least 1, got {cells_per_type}")

    if seed is not None:
        random.seed(seed)
        np.random.seed(seed)

    # Each node draws randint(1, 10) * sample_res cells. Clamp to 1 so small
    # cells_per_type values cannot produce a tree of empty branches.
    sample_res = max(1, cells_per_type // 5)
    tree_structure = create_random_binary_tree(n_cell_types, sample_res)

    adata = generate_ad(tree_structure, n_genes, show_plots=show_plots)

    print("Generated realistic dataset:")
    print(f"  {adata.n_obs} cells")
    print(f"  {adata.n_vars} genes")
    print(f"  {len(adata.obs['branch'].unique())} branches")
    print(f"  {len(adata.obs['leiden'].unique())} leiden clusters")

    return adata