arborist

Rank candidate SNV phylogenetic trees using scDNA-seq read counts.

This function scores each candidate clonal tree with a variational inference scheme and returns an evidence lower bound (ELBO)-based score for each tree. The SNV-to-cluster prior is initialized from the supplied SNV clustering, and both cell-to-clone and SNV-to-cluster assignments are optimized under the model.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| tree_list | list of list of tuple | Candidate phylogenetic trees to be ranked. Each tree is represented as a list of directed edges (parent, child). All trees must contain the same set of clone identifiers. | required |
| read_counts | DataFrame | Data frame with columns ["snv", "cell", "alt", "total"] in any order, containing per-cell read counts for each SNV. | required |
| snv_clusters | DataFrame | Data frame with columns ["snv", "cluster"] in any order, giving an initial hard assignment of SNVs to clusters. | required |
| alpha | float | Per-base sequencing error rate used to compute log-likelihoods for presence/absence of an SNV. | 0.001 |
| max_iter | int | Maximum number of coordinate-ascent iterations in the variational inference procedure. | 10 |
| tolerance | float | Convergence threshold on the change in ELBO between iterations. | 1 |
| gamma | float | Prior probability mass placed on the initial SNV cluster assignment in q_y. The remaining mass is spread uniformly over alternative clusters. | 0.7 |
| add_normal | bool | If True, prepend a normal clone as a new root node to each tree in tree_list. | False |
| threads | int | Number of threads to use for numba-parallelized computations. | 10 |
| verbose | bool | If True, enable informative logging messages during fitting. | False |
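
To make the role of gamma concrete, here is a minimal NumPy sketch of a prior of the kind described for q_y. `init_cluster_prior` is a hypothetical helper written for illustration, not the library's `initialize_q_y`; it assumes clusters are already encoded as integer indices 0..K-1.

```python
import numpy as np


def init_cluster_prior(initial_cluster_idx, n_clusters, gamma=0.7):
    """Hypothetical helper: build an (n_snvs, n_clusters) prior matrix.

    Each row puts mass `gamma` on the SNV's initial cluster and spreads
    the remaining 1 - gamma uniformly over the other clusters.
    """
    initial_cluster_idx = np.asarray(initial_cluster_idx)
    n_snvs = initial_cluster_idx.shape[0]
    if n_clusters == 1:
        # With a single cluster there is nowhere else to put mass.
        return np.ones((n_snvs, 1))
    off_mass = (1.0 - gamma) / (n_clusters - 1)
    q_y = np.full((n_snvs, n_clusters), off_mass)
    q_y[np.arange(n_snvs), initial_cluster_idx] = gamma
    return q_y


# Three SNVs initially assigned to clusters 0, 2 and 1 (toy values).
q_y = init_cluster_prior([0, 2, 1], n_clusters=3, gamma=0.7)
```

With three clusters and gamma=0.7, each alternative cluster receives (1 - 0.7) / 2 = 0.15. Values of gamma closer to 1 trust the initial clustering more; values closer to 1/K approach a uniform prior.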

Returns:

| Name | Type | Description |
|---|---|---|
| likelihoods | dict[int, float] | Mapping from tree index to its ELBO (expected log joint) under the variational posterior. |
| best_fit | TreeFit | TreeFit object for the top-ranked tree, containing the tree, ELBO, posterior cell-to-clone distribution q_z, posterior SNV-to-cluster distribution q_y, and associated index mappings. |
| all_tree_fits | dict[int, TreeFit] | Dictionary mapping each tree index to its corresponding TreeFit object. |
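
Because likelihoods is a plain dict keyed by tree index, a full ranking can be derived from it directly. The sketch below uses made-up ELBO values in place of a real run and relies only on the documented return types, not on any TreeFit attribute names.

```python
# Made-up ELBO values standing in for the `likelihoods` dict returned by arborist.
likelihoods = {0: -1523.4, 1: -1498.7, 2: -1510.2}

# Tree indices ordered from highest to lowest ELBO.
ranked = sorted(likelihoods, key=likelihoods.get, reverse=True)
best_idx = ranked[0]

# ELBO gap of each tree relative to the best one (0.0 for the best tree).
gaps = {idx: likelihoods[best_idx] - likelihoods[idx] for idx in ranked}

for rank, idx in enumerate(ranked, start=1):
    print(f"rank {rank}: tree {idx}, ELBO {likelihoods[idx]:.1f}, gap {gaps[idx]:.1f}")
```

The same index can then be used to look up the corresponding entry in all_tree_fits.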

Raises:

| Type | Description |
|---|---|
| ValueError | If the candidate trees in tree_list do not all share the same set of clone identifiers. |
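
To catch this error before a long run, the clone sets can be compared up front. The helpers below are hypothetical and only mirror the documented check; they assume the clone set of a tree is simply every node that appears in one of its edges, which may differ from what the library's internal `tree_to_clone_set` does.

```python
def clone_set(tree):
    """Hypothetical helper: all clone identifiers appearing in an edge list."""
    return {node for edge in tree for node in edge}


def check_same_clones(tree_list):
    """Raise ValueError if any tree's clone set differs from the first tree's."""
    reference = clone_set(tree_list[0])
    for idx, tree in enumerate(tree_list):
        found = clone_set(tree)
        if found != reference:
            raise ValueError(
                f"Tree {idx} has clones {sorted(found)}, expected {sorted(reference)}."
            )


good = [[(0, 1), (1, 2)], [(0, 2), (2, 1)]]  # same clones, different topology
bad = [[(0, 1), (1, 2)], [(0, 1), (1, 3)]]   # second tree swaps clone 2 for clone 3

check_same_clones(good)  # passes silently
try:
    check_same_clones(bad)
    mismatch_detected = False
except ValueError:
    mismatch_detected = True
```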

Notes

Only SNVs with at least one variant (alt) read, summed over all cells, are retained for inference; an SNV with no variant reads carries no signal for placing it in the tree. SNVs with zero total variant counts are dropped from read_counts before fitting.
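
The filter can be reproduced on a toy table to see in advance which SNVs will be dropped. The pandas steps follow the source listing on this page; the data values are invented for illustration.

```python
import pandas as pd

# Toy long-format read counts: s2 has no variant reads in any cell.
read_counts = pd.DataFrame(
    {
        "snv": ["s1", "s1", "s2", "s2", "s3"],
        "cell": ["c1", "c2", "c1", "c2", "c1"],
        "alt": [3, 0, 0, 0, 1],
        "total": [10, 8, 5, 7, 4],
    }
)

# Total variant reads per SNV, summed over all cells.
alt_sum = read_counts.groupby("snv")["alt"].sum()
valid_snvs = alt_sum[alt_sum > 0].index

# Keep only rows belonging to SNVs with at least one variant read.
filtered = read_counts[read_counts["snv"].isin(valid_snvs)].copy()
dropped = sorted(set(read_counts["snv"]) - set(filtered["snv"]))
```

Note that s1 is kept even though one of its cells has zero variant reads: the filter looks at the per-SNV total, not at individual rows.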

Examples:

>>> likelihoods, best_fit, all_fits = arborist(
...     tree_list,
...     read_counts,
...     snv_clusters,
...     alpha=0.001,
...     max_iter=25,
...     gamma=0.7,
... )
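
The call above assumes the three inputs already exist. A minimal sketch of toy inputs in the documented shapes follows. All identifiers and counts are invented, and the use of integer clone labels is an assumption. Judging from the source listing, the cluster labels in snv_clusters appear to be expected to match the non-root clone identifiers of the trees.

```python
import pandas as pd

# Two candidate trees over the same clones {0, 1, 2, 3}; clone 0 is the root.
tree_list = [
    [(0, 1), (1, 2), (1, 3)],  # clones 2 and 3 both branch from clone 1
    [(0, 1), (1, 2), (2, 3)],  # linear chain 0 -> 1 -> 2 -> 3
]

# Long-format read counts: one row per (snv, cell) pair.
read_counts = pd.DataFrame(
    {
        "snv": ["s1", "s1", "s2", "s2", "s3", "s3"],
        "cell": ["c1", "c2", "c1", "c2", "c1", "c2"],
        "alt": [4, 0, 2, 0, 0, 3],
        "total": [10, 9, 6, 8, 7, 5],
    }
)

# Initial hard SNV clustering (assumed: labels are non-root clone identifiers).
snv_clusters = pd.DataFrame({"snv": ["s1", "s2", "s3"], "cluster": [1, 2, 3]})

# Cheap sanity checks before calling arborist(tree_list, read_counts, snv_clusters).
clone_sets = [{node for edge in tree for node in edge} for tree in tree_list]
same_clones = all(cs == clone_sets[0] for cs in clone_sets)
has_columns = {"snv", "cell", "alt", "total"} <= set(read_counts.columns)
counts_consistent = bool((read_counts["alt"] <= read_counts["total"]).all())
```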
Source code in src/arborist/arborist.py
def arborist(
    tree_list: list,
    read_counts: pd.DataFrame,
    snv_clusters: pd.DataFrame,
    alpha: float = 0.001,
    max_iter: int = 10,
    tolerance: float = 1,
    gamma: float = 0.7,
    add_normal: bool = False,
    threads: int = 10,
    verbose: bool = False,
) -> tuple:
    """
    Rank candidate SNV phylogenetic trees using scDNA-seq read counts.

    This function scores each candidate clonal tree with a variational
    inference scheme and returns an evidence lower bound (ELBO)-based score
    for each tree. The SNV-to-cluster prior is initialized from the supplied
    SNV clustering, and both cell-to-clone and SNV-to-cluster assignments are
    optimized under the model.

    Parameters
    ----------
    tree_list : list of list of tuple
        Candidate phylogenetic trees to be ranked. Each tree is represented as
        a list of directed edges ``(parent, child)``. All trees must contain
        the same set of clone identifiers.
    read_counts : pandas.DataFrame
        Data frame with columns ``["snv", "cell", "alt", "total"]`` in any order, 
        containing per-cell read counts for each SNV.
    snv_clusters : pandas.DataFrame
        Data frame with columns ``["snv", "cluster"]`` in any order, giving an
        initial hard assignment of SNVs to clusters.
    alpha : float, optional
        Per-base sequencing error rate used to compute log-likelihoods for
        presence/absence of an SNV (default is ``0.001``).
    max_iter : int, optional
        Maximum number of coordinate-ascent iterations in the variational
        inference procedure (default is ``10``).
    tolerance : float, optional
        Convergence threshold on the change in ELBO between iterations
        (default is ``1.0``).
    gamma : float, optional
        Prior probability mass placed on the initial SNV cluster assignment
        in ``q_y`` (default is ``0.7``). The remaining mass is spread
        uniformly over alternative clusters.
    add_normal : bool, optional
        If ``True``, prepend a normal clone as a new root node to each tree in
        ``tree_list`` (default is ``False``).
    threads : int, optional
        Number of threads to use for numba-parallelized computations
        (default is ``10``).
    verbose : bool, optional
        If ``True``, enable informative logging messages during fitting
        (default is ``False``).

    Returns
    -------
    likelihoods : dict[int, float]
        Mapping from tree index to its ELBO (expected log joint) under the
        variational posterior.
    best_fit : TreeFit
        ``TreeFit`` object for the top-ranked tree, containing the tree,
        ELBO, posterior cell-to-clone distribution ``q_z``, posterior
        SNV-to-cluster distribution ``q_y``, and associated index mappings.
    all_tree_fits : dict[int, TreeFit]
        Dictionary mapping each tree index to its corresponding ``TreeFit``
        object.

    Raises
    ------
    ValueError
        If the candidate trees in ``tree_list`` do not all share the same
        set of clone identifiers.

    Notes
    -----
    Only SNVs with at least one variant (alt) read, summed over all cells,
    are retained for inference; an SNV with no variant reads carries no
    signal for placing it in the tree. SNVs with zero total variant counts
    are dropped from ``read_counts`` before fitting.

    Examples
    --------
    >>> likelihoods, best_fit, all_fits = arborist(
    ...     tree_list,
    ...     read_counts,
    ...     snv_clusters,
    ...     alpha=0.001,
    ...     max_iter=25,
    ...     gamma=0.7,
    ... )
    """


    if verbose:
        logger = logging.getLogger()

    numba.set_num_threads(threads)

    if add_normal:
        if verbose:
            logger.info("Prepending a normal clone as the root of each candidate tree...")
        tree_list = add_normal_clone(tree_list)

    tree = tree_list[0]

    # The root (the only node with in-degree 0) is assumed to be the normal clone.
    temp_tree = nx.DiGraph(tree)
    normal = [n for n in temp_tree if temp_tree.in_degree(n) == 0][0]
    clone_set = tree_to_clone_set(tree)
    for tree in tree_list:
        if tree_to_clone_set(tree) != clone_set:
            raise ValueError("All trees must have the same set of clones.")

    clones = list(clone_set)

    clusters = [c for c in clones if c != normal]

    # Sort identifiers so the clone and cluster index mappings are deterministic.
    clones.sort()
    clusters.sort()
    clone_to_idx = {c: i for i, c in enumerate(clones)}
    cluster_to_idx = {c: i for i, c in enumerate(clusters)}


    if verbose:
        logger.info(f"Removing SNVs with 0 variant reads across all cells...")

    alt_sum = read_counts.groupby("snv")["alt"].sum()
    valid_snvs = alt_sum[alt_sum > 0].index


    # valid SNVs are SNVs that have at least one variant read across all cells
    # otherwise we have no signal to place them in the tree

    read_counts = read_counts[read_counts["snv"].isin(valid_snvs)].copy()

    # appends columns log_absent and log_present to read_counts
    if verbose:
        logger.info(f"Caching log-likelihoods for presence/absence...")
    read_counts, cell_to_idx, snv_to_idx = precompute_log_likelihoods(
        read_counts, alpha
    )

    if verbose:
        logger.info(f"Initializing the SNV cluster assignment prior...")
    q_y_init = initialize_q_y(snv_clusters, cluster_to_idx, snv_to_idx, gamma)
    cell_idx, snv_idx, log_like_matrix = build_sparse_input(
        read_counts, cell_to_idx, snv_to_idx
    )

    n_cells = len(cell_to_idx)
    n_snvs = len(snv_to_idx)
    n_clones = len(clones)


    cell_ptr, _, snv_index_cell_sort, log_like_matrix_cell_sort = (
        build_index_pointers(
            cell_idx, snv_idx, n_cells, log_like_matrix=log_like_matrix
        )
    )
    snv_ptr, _ , cell_index_snv_sort, log_like_matrix_snv_sort = (
        build_index_pointers(snv_idx, cell_idx, n_snvs, log_like_matrix=log_like_matrix)
    )

    best_likelihood = -np.inf
    best_fit = None  # remains None only if no tree yields a comparable ELBO (e.g. NaN)
    likelihoods = {}

    all_tree_fits = {}

    run = run_variational_inference

    if verbose:
        logger.info(f"Starting Arborist...")
        logger.info(f"Number of candidate trees: {len(tree_list)}")
        logger.info(f"Number of clones: {n_clones}")
        logger.info(f"Number of SNV clusters: {len(cluster_to_idx)}")
        logger.info(f"Number of cells: {n_cells}")
        logger.info(f"Number of SNVs: {n_snvs}")
    # else:
    #     print("Running Arborist in cell MAP assignment mode....")
    #     run = run_simple_max_likelihood

    for idx, tree in enumerate(tree_list):

        presence = enumerate_presence(tree, clone_to_idx, cluster_to_idx)

        expected_log_like, q_z, q_y = run(
            presence,
            log_like_matrix_cell_sort,
            log_like_matrix_snv_sort,
            cell_idx=cell_index_snv_sort,
            snv_idx=snv_index_cell_sort,
            n_cells=n_cells,
            n_snvs=n_snvs,
            n_clones=n_clones,
            q_y_init=q_y_init,
            cell_ptr=cell_ptr,
            snv_ptr=snv_ptr,
            max_iter=max_iter,
            tolerance=tolerance,
        )

        tfit = TreeFit(
            tree_list[idx],
            idx,
            expected_log_like,
            q_z,
            q_y,
            cell_to_idx,
            snv_to_idx,
            clone_to_idx,
            cluster_to_idx,
        )
        all_tree_fits[idx] = tfit
        if verbose:
            logger.info(f"Tree {idx} fit with ELBO: {expected_log_like:.2f}")
        likelihoods[idx] = expected_log_like
        if expected_log_like > best_likelihood:
            best_fit = tfit
            best_likelihood = expected_log_like

    if verbose:
        logger.info(f"Done fitting all {len(tree_list)} candidate trees!")
    return likelihoods, best_fit, all_tree_fits