diff --git a/.gitignore b/.gitignore index 4cb7d7a2..1c3ff383 100755 --- a/.gitignore +++ b/.gitignore @@ -164,4 +164,5 @@ build stdout.log notebooks/.ipynb_checkpoints cassiopeia/tools/branch_length_estimator/_iid_exponential_bayesian.cpp -docs/api/reference/** \ No newline at end of file +docs/api/reference/** +cassiopeia/config.ini diff --git a/README.md b/README.md index e0cab614..efb1b77a 100755 --- a/README.md +++ b/README.md @@ -45,7 +45,9 @@ For developers: * Run the command ``gurobi.sh`` from a terminal window * From the Gurobi installation directory (where there is a setup.py file), use ``python setup.py install --user`` -4. Install Cassiopeia by first changing into the Cassiopeia directory and then `pip3 install .`. To install dev and docs requirements, you can run `pip3 install .[dev,docs]`. +4. [Optional] To use fast versions of Neighbor-Joining and UPGMA, install [CCPhylo](https://bitbucket.org/genomicepidemiology/ccphylo/src/master/) then set ccphylo_path in the config.ini file in the cassiopeia directory. + +5. Install Cassiopeia by first changing into the Cassiopeia directory and then `pip3 install .`. To install dev and docs requirements, you can run `pip3 install .[dev,docs]`. To verify that it installed correctly, try running our tests with `pytest`. diff --git a/cassiopeia/config.ini b/cassiopeia/config.ini new file mode 100755 index 00000000..cc497800 --- /dev/null +++ b/cassiopeia/config.ini @@ -0,0 +1,2 @@ +[Paths] +ccphylo_path = /path/to/ccphylo/ccphylo diff --git a/cassiopeia/solver/DistanceSolver.py b/cassiopeia/solver/DistanceSolver.py index cfcdd099..9fdbaa30 100644 --- a/cassiopeia/solver/DistanceSolver.py +++ b/cassiopeia/solver/DistanceSolver.py @@ -3,19 +3,28 @@ the inference procedures that inherit from this method will need to implement methods for selecting "cherries" and updating the dissimilarity map. Methods that will inherit from this class by default are Neighbor-Joining and UPGMA. -There may be other subclasses of this. +There may be other subclasses of this. Currently also implements a method for +solving trees with CCPhylo but this will be moved with switch to compositional +framework. """ +import os + import abc +import subprocess +import tempfile from typing import Callable, Dict, List, Optional, Tuple +import configparser +import ete3 import networkx as nx import numpy as np import pandas as pd from cassiopeia.data import CassiopeiaTree from cassiopeia.mixins import DistanceSolverError -from cassiopeia.solver import CassiopeiaSolver, solver_utilities - +from cassiopeia.solver import (CassiopeiaSolver, + solver_utilities) +from cassiopeia.data import utilities class DistanceSolver(CassiopeiaSolver.CassiopeiaSolver): """ @@ -74,6 +83,9 @@ def __init__( self.dissimilarity_function = dissimilarity_function self.add_root = add_root + + if "ccphylo" in self._implementation: + self._setup_ccphylo() def get_dissimilarity_map( self, @@ -108,7 +120,6 @@ def get_dissimilarity_map( return dissimilarity_map - def solve( self, cassiopeia_tree: CassiopeiaTree, @@ -135,6 +146,24 @@ def solve( removes artifacts caused by arbitrarily resolving polytomies. logfile: File location to log output. Not currently used. """ + + if self._implementation == "ccphylo_dnj": + self._ccphylo_solve(cassiopeia_tree,layer, + collapse_mutationless_edges,logfile,method="dnj") + return + elif self._implementation == "ccphylo_nj": + self._ccphylo_solve(cassiopeia_tree,layer, + collapse_mutationless_edges,logfile,method="nj") + return + elif self._implementation == "ccphylo_hnj": + self._ccphylo_solve(cassiopeia_tree,layer, + collapse_mutationless_edges,logfile,method="hnj") + return + elif self._implementation == "ccphylo_upgma": + self._ccphylo_solve(cassiopeia_tree,layer, + collapse_mutationless_edges,logfile,method="upgma") + return + node_name_generator = solver_utilities.node_name_generator() dissimilarity_map = self.get_dissimilarity_map(cassiopeia_tree, layer) @@ -197,6 +226,121 @@ def solve( infer_ancestral_characters=True ) + def _ccphylo_solve( + self, + cassiopeia_tree: CassiopeiaTree, + layer: Optional[str] = None, + collapse_mutationless_edges: bool = False, + logfile: str = "stdout.log", + method: str = "dnj" + ) -> None: + """Solves a tree using fast distance-based algorithms implemented by + CCPhylo. To call this method the CCPhlyo package must be installed + and the ccphylo_path must be set in the config file. The method + attribute specifies which algorithm to use. The function will update the + `tree`. + + Args: + cassiopeia_tree: CassiopeiaTree object to be populated + layer: Layer storing the character matrix for solving. If None, the + default character matrix is used in the CassiopeiaTree. + collapse_mutationless_edges: Indicates if the final reconstructed + tree should collapse mutationless edges based on internal states + inferred by Camin-Sokal parsimony. In scoring accuracy, this + removes artifacts caused by arbitrarily resolving polytomies. + logfile: File location to log output. Not currently used. + """ + + dissimilarity_map = self.get_dissimilarity_map(cassiopeia_tree, layer) + + with tempfile.TemporaryDirectory() as temp_dir: + + # save dissimilarity map as phylip file + dis_path = os.path.join(temp_dir, "dist.phylip") + tree_path = os.path.join(temp_dir, "tree.nwk") + solver_utilities.save_dissimilarity_as_phylip(dissimilarity_map, + dis_path) + + # run ccphylo + command = (f"{self._ccphylo_path} tree -i {dis_path} -o " + f"{tree_path} -m {method}") + + process = subprocess.Popen(command, shell=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE) + stdout, stderr = process.communicate() + T = ete3.Tree(tree_path, format=1) + + # remove temporary files + os.remove(dis_path) + os.remove(tree_path) + + # Covert to networkx + tree = utilities.ete3_to_networkx(T).to_undirected() + + # find last split + midpoint = T.get_midpoint_outgroup() + root = T.get_tree_root() + if midpoint in root.children: + last_split = [root.name,midpoint.name] + else: + last_split = [root.name,root.children[0].name] + tree.remove_edge(last_split[0],last_split[1]) + + # root tree + tree = self.root_tree(tree,cassiopeia_tree.root_sample_name,last_split) + + # remove root from character matrix before populating tree + if ( + cassiopeia_tree.root_sample_name + in cassiopeia_tree.character_matrix.index + ): + cassiopeia_tree.character_matrix = ( + cassiopeia_tree.character_matrix.drop( + index=cassiopeia_tree.root_sample_name + ) + ) + + # populate tree + cassiopeia_tree.populate_tree(tree,layer=layer) + cassiopeia_tree.collapse_unifurcations() + + # collapse mutationless edges + if collapse_mutationless_edges: + cassiopeia_tree.collapse_mutationless_edges( + infer_ancestral_characters=True + ) + + def _setup_ccphylo(self) -> None: + """Sets up the ccphylo solver by getting the ccphylo_path from the + config file and checking that it is valid. + """ + + # get ccphylo path + config = configparser.ConfigParser() + config.read(os.path.join(os.path.dirname(__file__),"..","config.ini")) + self._ccphylo_path = config.get("Paths","ccphylo_path") + + #check that ccphylo_path is valid + if not os.path.exists(self._ccphylo_path): + raise DistanceSolverError( + f"ccphylo_path {self._ccphylo_path} does not exist. To use fast " + "versions of Neighbor-Joining and UPGMA please install CCPhylo " + "(https://bitbucket.org/genomicepidemiology/ccphylo/src/master/)" + "set the ccphylo_path in the config.ini file then reinstall " + "Cassiopeia." + ) + + #check that ccphylo_path is executable + if not os.access(self._ccphylo_path, os.X_OK): + raise DistanceSolverError( + f"ccphylo_path {self._ccphylo_path} is not executable. To use " + "fast versions of Neighbor-Joining and UPGMA please install CCPhylo" + " (https://bitbucket.org/genomicepidemiology/ccphylo/src/master/) " + "set the ccphylo_path in the config.ini file then reinstall " + "Cassiopeia." + ) + def setup_dissimilarity_map( self, cassiopeia_tree: CassiopeiaTree, layer: Optional[str] = None ) -> None: @@ -228,8 +372,8 @@ def setup_dissimilarity_map( else: raise DistanceSolverError( - "Please specify an explicit root sample in the Cassiopeia Tree" - " or specify the solver to add an implicit root" + "Please specify an explicit root sample in the Cassiopeia" + " Tree or specify the solver to add an implicit root" ) if cassiopeia_tree.get_dissimilarity_map() is None: diff --git a/cassiopeia/solver/NeighborJoiningSolver.py b/cassiopeia/solver/NeighborJoiningSolver.py index 848d9a7d..274be39e 100755 --- a/cassiopeia/solver/NeighborJoiningSolver.py +++ b/cassiopeia/solver/NeighborJoiningSolver.py @@ -1,5 +1,5 @@ """ -This file stores a subclass of DistanceSolver, NeighborJoining. The +This file stores a subclass of DistanceSolver, NeighborJoiningSolver. The inference procedure is the Neighbor-Joining algorithm proposed by Saitou and Nei (1987) that iteratively joins together samples that minimize the Q-criterion on the dissimilarity map. @@ -13,13 +13,13 @@ import pandas as pd from cassiopeia.data import CassiopeiaTree +from cassiopeia.mixins import DistanceSolverError from cassiopeia.solver import ( DistanceSolver, dissimilarity_functions, solver_utilities, ) - class NeighborJoiningSolver(DistanceSolver.DistanceSolver): """ Neighbor-Joining class for Cassiopeia. @@ -27,7 +27,8 @@ class NeighborJoiningSolver(DistanceSolver.DistanceSolver): Implements the Neighbor-Joining algorithm described by Saitou and Nei (1987) as a derived class of DistanceSolver. This class inherits the generic `solve` method, but implements its own procedure for finding cherries by - minimizing the Q-criterion between samples. + minimizing the Q-criterion between samples. If fast is set to True, + a fast NJ implementation of is used. Args: dissimilarity_function: A function by which to compute the dissimilarity @@ -43,6 +44,15 @@ class NeighborJoiningSolver(DistanceSolver.DistanceSolver): "inverse": Transforms each probability p by taking 1/p "square_root_inverse": Transforms each probability by the the square root of 1/p + fast: Whether to use a fast implementation of Neighbor-Joining. + implementation: Which fast implementation to use. Options are: + "ccphylo_dnj": CCPhylo implementation the Dynamic Neighbor-Joining + algorithm described by Clausen (2023). Solution in guaranteed + to be exact. + "ccphylo_hnj": CCPhylo implementation of the Heuristic + Neighbor-Joining algorithm described by Clausen (2023). + Solution is not guaranteed to be exact. + "ccphylo_nj": CCPhylo implementation of the Neighbor-Joining. Attributes: dissimilarity_function: Function used to compute dissimilarity between @@ -62,8 +72,21 @@ def __init__( ] = dissimilarity_functions.weighted_hamming_distance, add_root: bool = False, prior_transformation: str = "negative_log", + fast: bool = False, + implementation: str = "ccphylo_dnj", ): + if fast: + if implementation in ["ccphylo_dnj", "ccphylo_hnj", "ccphylo_nj"]: + self._implementation = implementation + else: + raise DistanceSolverError( + "Invalid fast implementation of Neighbor-Joining. Options " + "are: 'ccphylo_dnj', 'ccphylo_hnj', 'ccphylo_nj'" + ) + else: + self._implementation = "generic_nj" + super().__init__( dissimilarity_function=dissimilarity_function, add_root=add_root, diff --git a/cassiopeia/solver/SpectralNeighborJoiningSolver.py b/cassiopeia/solver/SpectralNeighborJoiningSolver.py index ec1c8ca0..acc6d431 100644 --- a/cassiopeia/solver/SpectralNeighborJoiningSolver.py +++ b/cassiopeia/solver/SpectralNeighborJoiningSolver.py @@ -71,6 +71,8 @@ def __init__( add_root: bool = False, prior_transformation: str = "negative_log", ): + self._implementation = "generic_spectral_nj" + super().__init__( dissimilarity_function=similarity_function, add_root=add_root, @@ -80,6 +82,7 @@ def __init__( self._similarity_map = None self._lambda_indices = None + def get_dissimilarity_map( self, cassiopeia_tree: CassiopeiaTree, layer: Optional[str] = None ) -> pd.DataFrame: diff --git a/cassiopeia/solver/UPGMASolver.py b/cassiopeia/solver/UPGMASolver.py index ba535de2..cc87b648 100644 --- a/cassiopeia/solver/UPGMASolver.py +++ b/cassiopeia/solver/UPGMASolver.py @@ -13,6 +13,7 @@ import pandas as pd from cassiopeia.data import CassiopeiaTree +from cassiopeia.mixins import DistanceSolverError from cassiopeia.solver import DistanceSolver, dissimilarity_functions @@ -26,7 +27,7 @@ class UPGMASolver(DistanceSolver.DistanceSolver): dissimilarity between samples. After joining nodes, the dissimilarities are updated by averaging the distances of elements in the new cluster with each existing node. Produces a rooted tree that is assumed to be - ultrametric. + ultrametric. If fast is set to True, a fast UPGMA implementation of is used. Args: dissimilarity_function: A function by which to compute the dissimilarity @@ -38,6 +39,9 @@ class UPGMASolver(DistanceSolver.DistanceSolver): "inverse": Transforms each probability p by taking 1/p "square_root_inverse": Transforms each probability by the the square root of 1/p + fast: Whether to use a fast implementation of UPGMA. + implementation: Which fast implementation to use. Options are: + "ccphylo_upgma": Uses the fast UPGMA implementation from CCPhylo. Attributes: dissimilarity_function: Function used to compute dissimilarity between samples. @@ -54,8 +58,21 @@ def __init__( ] ] = dissimilarity_functions.weighted_hamming_distance, prior_transformation: str = "negative_log", + fast: bool = False, + implementation: str = "ccphylo_upgma", ): + if fast: + if implementation in ["ccphylo_upgma"]: + self._implementation = implementation + else: + raise DistanceSolverError( + "Invalid fast implementation of UPGMA. Options are: " + "'ccphylo_upgma'" + ) + else: + self._implementation = "generic_upgma" + super().__init__( dissimilarity_function=dissimilarity_function, add_root=True, diff --git a/cassiopeia/solver/solver_utilities.py b/cassiopeia/solver/solver_utilities.py index c0fd7f51..f79036de 100755 --- a/cassiopeia/solver/solver_utilities.py +++ b/cassiopeia/solver/solver_utilities.py @@ -7,6 +7,7 @@ import ete3 from hashlib import blake2b import numpy as np +import pandas as pd import time from cassiopeia.mixins import PriorTransformationError @@ -122,3 +123,24 @@ def convert_sample_names_to_indices( name_to_index = dict(zip(names, range(len(names)))) return list(map(lambda x: name_to_index[x], samples)) + +def save_dissimilarity_as_phylip( + dissimilarity_map: pd.DataFrame, path: str + ) -> None: + """Saves a dissimilarity map as a phylip file. + + Args: + dissimilarity_map: A dissimilarity map + path: The path to save the phylip file + + Returns: + None + """ + dissimilarity_np = dissimilarity_map.to_numpy() + n = dissimilarity_np.shape[0] + with open(path, "w") as f: + f.write("{}\n".format(n)) + for i in range(n): + row = dissimilarity_np[i, :i+1] + formatted_values = '\t'.join(map('{:.4f}'.format, row)) + f.write("{}\t{}\n".format(dissimilarity_map.index[i], formatted_values)) diff --git a/docs/notebooks b/docs/notebooks index edb8f02d..8f9a5b2e 120000 --- a/docs/notebooks +++ b/docs/notebooks @@ -1 +1 @@ -../notebooks/ \ No newline at end of file +../notebooks \ No newline at end of file diff --git a/test/solver_tests/ccphylo_solver_test.py b/test/solver_tests/ccphylo_solver_test.py new file mode 100755 index 00000000..759231de --- /dev/null +++ b/test/solver_tests/ccphylo_solver_test.py @@ -0,0 +1,257 @@ +""" +Test the ccphylo solver implementations against the standard NJ and UPGMA +""" +import unittest +from typing import Dict, Optional +from unittest import mock + +import os + +import configparser +import itertools +import networkx as nx +import numpy as np +import pandas as pd + +import cassiopeia as cas + + +def find_triplet_structure(triplet, T): + a, b, c = triplet[0], triplet[1], triplet[2] + a_ancestors = [node for node in nx.ancestors(T, a)] + b_ancestors = [node for node in nx.ancestors(T, b)] + c_ancestors = [node for node in nx.ancestors(T, c)] + ab_common = len(set(a_ancestors) & set(b_ancestors)) + ac_common = len(set(a_ancestors) & set(c_ancestors)) + bc_common = len(set(b_ancestors) & set(c_ancestors)) + structure = "-" + if ab_common > bc_common and ab_common > ac_common: + structure = "ab" + elif ac_common > bc_common and ac_common > ab_common: + structure = "ac" + elif bc_common > ab_common and bc_common > ac_common: + structure = "bc" + return structure + +# specify dissimilarity function for solvers to use +def delta_fn( + x: np.array, + y: np.array, + missing_state: int, + priors: Optional[Dict[int, Dict[int, float]]], +): + d = 0 + for i in range(len(x)): + if x[i] != y[i]: + d += 1 + return d + +# only run test if ccphylo_path is specified in config.ini +config = configparser.ConfigParser() +config.read(os.path.join(os.path.dirname(__file__), + "..","..","cassiopeia","config.ini")) +CCPHYLO_CONFIGURED = (config.get("Paths","ccphylo_path") != + "/path/to/ccphylo/ccphylo") +print(CCPHYLO_CONFIGURED) + + +class TestCCPhyloSolver(unittest.TestCase): + @unittest.skipUnless(CCPHYLO_CONFIGURED, "CCPhylo not configured.") + def setUp(self): + + # --------------------- General NJ --------------------- + cm = pd.DataFrame.from_dict( + { + "a": [0, 1, 2, 1, 0, 0, 2, 0, 0, 0], + "b": [1, 1, 2, 1, 0, 0, 2, 0, 0, 0], + "c": [2, 2, 2, 1, 0, 0, 2, 0, 0, 0], + "d": [1, 1, 1, 1, 0, 0, 2, 0, 0, 0], + "e": [0, 0, 0, 0, 1, 2, 1, 0, 2, 0], + "f": [0, 0, 0, 0, 2, 2, 1, 0, 2, 0], + "g": [0, 2, 0, 0, 1, 1, 1, 0, 2, 0], + "h": [0, 2, 0, 0, 1, 0, 0, 1, 2, 1], + "i": [1, 2, 0, 0, 1, 0, 0, 2, 2, 1], + "j": [1, 2, 0, 0, 1, 0, 0, 1, 1, 1], + }, + orient="index", + columns=["x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", + "x9", "x10"], + ) + + self.cm = cm + self.basic_tree = cas.data.CassiopeiaTree( + character_matrix=cm + ) + + self.nj_solver = cas.solver.NeighborJoiningSolver( + add_root=True,fast=False) + self.ccphylo_nj_solver = cas.solver.NeighborJoiningSolver( + add_root=True,fast=True,implementation="ccphylo_nj") + self.ccphylo_dnj_solver = cas.solver.NeighborJoiningSolver( + add_root=True,fast = True, implementation="ccphylo_dnj") + self.ccphylo_hnj_solver = cas.solver.NeighborJoiningSolver( + add_root=True, fast = True, implementation="ccphylo_hnj") + + self.ccphylo_upgma_solver = cas.solver.UPGMASolver(fast=True) + self.upgma_solver = cas.solver.UPGMASolver(fast=False) + + + # ------------- CM with Duplictes ----------------------- + duplicates_cm = pd.DataFrame.from_dict( + { + "a": [1, 1, 0], + "b": [1, 2, 0], + "c": [1, 2, 1], + "d": [2, 0, 0], + "e": [2, 0, 2], + "f": [2, 0, 2], + }, + orient="index", + columns=["x1", "x2", "x3"], + ) + + self.duplicate_tree = cas.data.CassiopeiaTree( + character_matrix=duplicates_cm + ) + + @unittest.skipUnless(CCPHYLO_CONFIGURED, "CCPhylo not configured.") + def test_ccphylo_invalid_input(self): + with self.assertRaises(cas.solver.DistanceSolver.DistanceSolverError): + nothing_solver = cas.solver.NeighborJoiningSolver(fast = True, + implementation="invalid") + + with self.assertRaises(cas.solver.DistanceSolver.DistanceSolverError): + nothing_solver = cas.solver.UPGMASolver(fast = True, + implementation="invalid") + + @unittest.skipUnless(CCPHYLO_CONFIGURED, "CCPhylo not configured.") + def test_ccphylo_nj_solver(self): + # NJ Solver + nj_tree = self.basic_tree.copy() + self.nj_solver.solve(nj_tree) + + # CCPhylo Fast NJ Solver + ccphylo_nj_tree = self.basic_tree.copy() + self.ccphylo_nj_solver.solve(ccphylo_nj_tree) + + # test for expected number of edges + self.assertEqual(len(nj_tree.edges), len(ccphylo_nj_tree.edges)) + + triplets = itertools.combinations(["a", "c", "d", "e"], 3) + for triplet in triplets: + expected_triplet = find_triplet_structure(triplet, + nj_tree.get_tree_topology()) + observed_triplet = find_triplet_structure(triplet, + ccphylo_nj_tree.get_tree_topology()) + self.assertEqual(expected_triplet, observed_triplet) + + @unittest.skipUnless(CCPHYLO_CONFIGURED, "CCPhylo not configured.") + def test_ccphylo_dnj_solver(self): + # NJ Solver + nj_tree = self.basic_tree.copy() + self.nj_solver.solve(nj_tree) + + # CCPhylo DNJ Solver + dnj_tree = self.basic_tree.copy() + self.ccphylo_dnj_solver.solve(dnj_tree) + + # test for expected number of edges + self.assertEqual(len(nj_tree.edges), len(dnj_tree.edges)) + + triplets = itertools.combinations(["a", "c", "d", "e"], 3) + for triplet in triplets: + expected_triplet = find_triplet_structure(triplet, + nj_tree.get_tree_topology()) + observed_triplet = find_triplet_structure(triplet, + dnj_tree.get_tree_topology()) + self.assertEqual(expected_triplet, observed_triplet) + + @unittest.skipUnless(CCPHYLO_CONFIGURED, "CCPhylo not configured.") + def test_ccphylo_hnj_solver(self): + # NJ Solver + nj_tree = self.basic_tree.copy() + self.nj_solver.solve(nj_tree) + + # CCPhylo HNJ Solver + hnj_tree = self.basic_tree.copy() + self.ccphylo_hnj_solver.solve(hnj_tree) + + # test for expected number of edges + self.assertEqual(len(nj_tree.edges), len(hnj_tree.edges)) + + + triplets = itertools.combinations(["a", "c", "d", "e"], 3) + for triplet in triplets: + expected_triplet = find_triplet_structure(triplet, + nj_tree.get_tree_topology()) + observed_triplet = find_triplet_structure(triplet, + hnj_tree.get_tree_topology()) + self.assertEqual(expected_triplet, observed_triplet) + + @unittest.skipUnless(CCPHYLO_CONFIGURED, "CCPhylo not configured.") + def test_ccphylo_upgma_solver(self): + # UPGMA Solver + upgma_tree = self.basic_tree.copy() + self.upgma_solver.solve(upgma_tree) + + # CCPhylo Fast UPGMA Solver + ccphylo_upgma_tree = self.basic_tree.copy() + self.ccphylo_upgma_solver.solve(ccphylo_upgma_tree) + + # test for expected number of edges + self.assertEqual(len(upgma_tree.edges), len(ccphylo_upgma_tree.edges)) + + + triplets = itertools.combinations(["a", "c", "d", "e"], 3) + for triplet in triplets: + expected_triplet = find_triplet_structure(triplet, + upgma_tree.get_tree_topology()) + observed_triplet = find_triplet_structure(triplet, + ccphylo_upgma_tree.get_tree_topology()) + self.assertEqual(expected_triplet, observed_triplet) + + @unittest.skipUnless(CCPHYLO_CONFIGURED, "CCPhylo not configured.") + def test_collapse_mutationless_edges_ccphylo(self): + # NJ Solver + nj_tree = self.basic_tree.copy() + self.nj_solver.solve(nj_tree, collapse_mutationless_edges=True) + + # Fast NJ Solver + ccphylo_nj_tree = self.basic_tree.copy() + self.ccphylo_nj_solver.solve(ccphylo_nj_tree, + collapse_mutationless_edges=True) + + # test for expected number of edges + self.assertEqual(len(nj_tree.edges), len(ccphylo_nj_tree.edges)) + + triplets = itertools.combinations(["a", "c", "d", "e"], 3) + for triplet in triplets: + expected_triplet = find_triplet_structure(triplet, + nj_tree.get_tree_topology()) + observed_triplet = find_triplet_structure(triplet, + ccphylo_nj_tree.get_tree_topology()) + self.assertEqual(expected_triplet, observed_triplet) + + @unittest.skipUnless(CCPHYLO_CONFIGURED, "CCPhylo not configured.") + def test_duplicate_sample_ccphylo(self): + # NJ Solver + nj_tree = self.duplicate_tree.copy() + self.nj_solver.solve(nj_tree) + + # Fast NJ Solver + ccphylo_nj_tree = self.duplicate_tree.copy() + self.ccphylo_nj_solver.solve(ccphylo_nj_tree) + + # test for expected number of edges + self.assertEqual(len(nj_tree.edges), len(ccphylo_nj_tree.edges)) + + triplets = itertools.combinations(["a", "b", "c", "d", "e", "f"], 3) + for triplet in triplets: + expected_triplet = find_triplet_structure(triplet, + nj_tree.get_tree_topology()) + observed_triplet = find_triplet_structure(triplet, + ccphylo_nj_tree.get_tree_topology()) + self.assertEqual(expected_triplet, observed_triplet) + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/test/solver_tests/dissimilarity_functions_test.py b/test/solver_tests/dissimilarity_functions_test.py index e3209c57..e9db7a82 100755 --- a/test/solver_tests/dissimilarity_functions_test.py +++ b/test/solver_tests/dissimilarity_functions_test.py @@ -6,6 +6,7 @@ from unittest import mock import numpy as np +import pandas as pd from cassiopeia.solver import dissimilarity_functions from cassiopeia.solver import solver_utilities @@ -298,6 +299,27 @@ def test_hamming_distance_ignore_missing(self): self.assertEqual(distance, 0) + def test_save_dissimilarity_as_phylip(self): + # Create a sample dissimilarity map + data = [[0.0, 0.5, 0.7], [0.5, 0.0, 0.3], [0.7, 0.3, 0.0]] + index = ['A', 'B', 'C'] + dissimilarity_map = pd.DataFrame(data, index=index) + + # Expected content in the mock file + expected_content = ("3\n" + "A\t0.0000\n" + "B\t0.5000\t0.0000\n" + "C\t0.7000\t0.3000\t0.0000\n") + + # Mock the open function to use a mock file object + with mock.patch("builtins.open", mock.mock_open()) as mock_file: + solver_utilities.save_dissimilarity_as_phylip(dissimilarity_map, + "dummy_path") + mock_file.assert_called_once_with("dummy_path", "w") + mock_file().write.assert_called() + self.assertIn(expected_content, "".join(call[0][0] for + call in mock_file().write.call_args_list)) + if __name__ == "__main__": unittest.main()