diff --git a/clipkit/msa.py b/clipkit/msa.py index a72beec..c1a2559 100644 --- a/clipkit/msa.py +++ b/clipkit/msa.py @@ -102,12 +102,20 @@ def is_any_entry_sequence_only_gaps(self) -> tuple[bool, Union[str, None]]: def trim( self, - mode: TrimmingMode, + mode: TrimmingMode=TrimmingMode.smart_gap, gap_threshold=None, + site_positions_to_trim=None, ) -> np.array: - self._site_positions_to_trim = self.determine_site_positions_to_trim( - mode, gap_threshold - ) + if site_positions_to_trim is not None: + if isinstance(site_positions_to_trim, list): + site_positions_to_trim = np.array(site_positions_to_trim) + if not isinstance(site_positions_to_trim, np.ndarray): + raise ValueError("site_positions_to_trim must be a list or np array") + self._site_positions_to_trim = site_positions_to_trim + else: + self._site_positions_to_trim = self.determine_site_positions_to_trim( + mode, gap_threshold + ) self._site_positions_to_keep = np.delete( np.arange(self._original_length), self._site_positions_to_trim ) diff --git a/clipkit/version.py b/clipkit/version.py index 9aa3f90..58039f5 100644 --- a/clipkit/version.py +++ b/clipkit/version.py @@ -1 +1 @@ -__version__ = "2.1.0" +__version__ = "2.1.1" diff --git a/tests/unit/test_msa.py b/tests/unit/test_msa.py new file mode 100644 index 0000000..7649c93 --- /dev/null +++ b/tests/unit/test_msa.py @@ -0,0 +1,51 @@ +import pytest +import numpy as np + +from Bio import AlignIO +from clipkit.msa import MSA + +def get_biopython_msa(file_path, file_format="fasta"): + return AlignIO.read(open(file_path), file_format) + + +class TestMSA(object): + def test_clipkit_msa_from_bio_msa(self): + bio_msa = get_biopython_msa("tests/unit/examples/simple.fa") + msa = MSA.from_bio_msa(bio_msa) + assert msa.header_info == [{'id': '1', 'name': '1', 'description': '1'}, {'id': '2', 'name': '2', 'description': '2'}, {'id': '3', 'name': '3', 'description': '3'}, {'id': '4', 'name': '4', 'description': '4'}, {'id': '5', 'name': '5', 'description': '5'}] + expected_seq_records = np.array([ + ['A', '-', 'G', 'T', 'A', 'T'], + ['A', '-', 'G', '-', 'A', 'T'], + ['A', '-', 'G', '-', 'T', 'A'], + ['A', 'G', 'A', '-', 'T', 'A'], + ['A', 'C', 'a', '-', 'T', '-'] + ]) + np.testing.assert_equal(msa.seq_records, expected_seq_records) + + def test_trim_by_provided_site_positions_np_array(self): + bio_msa = get_biopython_msa("tests/unit/examples/simple.fa") + msa = MSA.from_bio_msa(bio_msa) + sites_to_trim = np.array([1, 4]) + msa.trim(site_positions_to_trim=sites_to_trim) + expected_sites_kept = np.array([ + ['A', 'G', 'T', 'T'], + ['A', 'G', '-', 'T'], + ['A', 'G', '-', 'A'], + ['A', 'A', '-', 'A'], + ['A', 'a', '-', '-'] + ]) + np.testing.assert_equal(msa.sites_kept, expected_sites_kept) + + def test_trim_by_provided_site_positions_list(self): + bio_msa = get_biopython_msa("tests/unit/examples/simple.fa") + msa = MSA.from_bio_msa(bio_msa) + sites_to_trim = [1, 4] + msa.trim(site_positions_to_trim=sites_to_trim) + expected_sites_kept = np.array([ + ['A', 'G', 'T', 'T'], + ['A', 'G', '-', 'T'], + ['A', 'G', '-', 'A'], + ['A', 'A', '-', 'A'], + ['A', 'a', '-', '-'] + ]) + np.testing.assert_equal(msa.sites_kept, expected_sites_kept)