From 876b7b754a9757de83040315e2d4cfe0a5f9aefb Mon Sep 17 00:00:00 2001 From: Ryan Wick Date: Thu, 15 Dec 2016 22:42:52 +1100 Subject: [PATCH] More unit tests --- test/test_assembly_graph.py | 85 +++++++++++++++++++++++++++++++++---- unicycler/assembly_graph.py | 24 ----------- 2 files changed, 77 insertions(+), 32 deletions(-) diff --git a/test/test_assembly_graph.py b/test/test_assembly_graph.py index ca2e42fc..8dd8e868 100644 --- a/test/test_assembly_graph.py +++ b/test/test_assembly_graph.py @@ -24,16 +24,10 @@ def test_segments(self): self.assertEqual(len(self.graph.segments), 336) def test_forward_links(self): - link_count = 0 - for link_list in self.graph.forward_links.values(): - link_count += len(link_list) - self.assertEqual(link_count, 904) + self.assertEqual(sum(len(x) for x in self.graph.forward_links.values()), 904) def test_reverse_links(self): - link_count = 0 - for link_list in self.graph.reverse_links.values(): - link_count += len(link_list) - self.assertEqual(link_count, 904) + self.assertEqual(sum(len(x) for x in self.graph.reverse_links.values()), 904) def test_links_match(self): """ @@ -126,6 +120,81 @@ def test_get_all_gfa_link_lines(self): self.assertEqual(gfa_link_lines.count('\n'), 452) self.assertEqual(gfa_link_lines.count('25M'), 452) + def test_filter_by_read_depth(self): + # A loop segment can be removed only when its depth drops below the threshold. + self.assertEqual(len(self.graph.segments), 336) + self.graph.filter_by_read_depth(0.5) + self.assertEqual(len(self.graph.segments), 336) + self.graph.segments[68].depth = 21.0 + self.graph.filter_by_read_depth(0.5) + self.assertEqual(len(self.graph.segments), 336) + self.graph.segments[68].depth = 20.0 + self.graph.filter_by_read_depth(0.5) + self.assertEqual(len(self.graph.segments), 335) + + # A low-depth segment is only removed if its a dead end. + self.graph.segments[306].depth = 0.1 + self.graph.filter_by_read_depth(0.5) + self.assertEqual(len(self.graph.segments), 335) + self.graph.remove_segments([273]) + self.assertEqual(len(self.graph.segments), 334) + self.graph.filter_by_read_depth(0.5) + self.assertEqual(len(self.graph.segments), 333) + + # def test_filter_homopolymer_loops(self): + # pass + + def test_remove_segments(self): + self.assertEqual(len(self.graph.segments), 336) + self.assertEqual(sum(len(x) for x in self.graph.forward_links.values()), 904) + self.assertEqual(sum(len(x) for x in self.graph.reverse_links.values()), 904) + self.graph.remove_segments([276]) + self.assertEqual(len(self.graph.segments), 335) + self.assertEqual(sum(len(x) for x in self.graph.forward_links.values()), 902) + self.assertEqual(sum(len(x) for x in self.graph.reverse_links.values()), 902) + self.graph.remove_segments([273]) + self.assertEqual(len(self.graph.segments), 334) + self.assertEqual(sum(len(x) for x in self.graph.forward_links.values()), 894) + self.assertEqual(sum(len(x) for x in self.graph.reverse_links.values()), 894) + self.graph.remove_segments([67, 108, 222, 297]) + self.assertEqual(len(self.graph.segments), 330) + self.assertEqual(sum(len(x) for x in self.graph.forward_links.values()), 870) + self.assertEqual(sum(len(x) for x in self.graph.reverse_links.values()), 870) + + def test_remove_small_components(self): + self.assertEqual(len(self.graph.segments), 336) + self.assertEqual(sum(len(x) for x in self.graph.forward_links.values()), 904) + self.assertEqual(sum(len(x) for x in self.graph.reverse_links.values()), 904) + self.graph.remove_small_components(5000, 0) + self.assertEqual(len(self.graph.segments), 336) + self.assertEqual(sum(len(x) for x in self.graph.forward_links.values()), 904) + self.assertEqual(sum(len(x) for x in self.graph.reverse_links.values()), 904) + self.graph.remove_small_components(6000, 0) + self.assertEqual(len(self.graph.segments), 335) + self.assertEqual(sum(len(x) for x in self.graph.forward_links.values()), 902) + self.assertEqual(sum(len(x) for x in self.graph.reverse_links.values()), 902) + self.graph.remove_small_components(180000, 0) + self.assertEqual(len(self.graph.segments), 335) + self.assertEqual(sum(len(x) for x in self.graph.forward_links.values()), 902) + self.assertEqual(sum(len(x) for x in self.graph.reverse_links.values()), 902) + self.graph.remove_small_components(190000, 0) + self.assertEqual(len(self.graph.segments), 0) + self.assertEqual(sum(len(x) for x in self.graph.forward_links.values()), 0) + self.assertEqual(sum(len(x) for x in self.graph.reverse_links.values()), 0) + + # def test_remove_small_dead_ends(self): + # pass + + # def test_merge_all_possible(self): + # pass + + # def test_merge_simple_path(self): + # pass + + # def test_get_mean_path_depth(self): + # pass + + # class TestAssemblyGraphFunctionsGfa(unittest.TestCase): # """ # Tests various AssemblyGraph functions on a graph loaded from a GFA file. diff --git a/unicycler/assembly_graph.py b/unicycler/assembly_graph.py index 5a9489cc..1dbf1e83 100644 --- a/unicycler/assembly_graph.py +++ b/unicycler/assembly_graph.py @@ -363,30 +363,6 @@ def get_all_gfa_link_lines(self): gfa_link_lines += self.gfa_link_line(start, end) return gfa_link_lines - def get_fastg_header_with_links(self, segment, positive): - """ - Returns a full SPAdes-style FASTG header for a segment, including the leading '>', all of - the links, the trailing ';' and a newline. - """ - number = segment.number - if not positive: - number *= -1 - header = '>' + segment.get_fastg_header(positive) - if number in self.forward_links: - header += ':' - next_segment_headers = [] - for next_num in self.forward_links[number]: - if next_num < 0: - next_positive = False - next_num *= -1 - else: - next_positive = True - next_segment = self.segments[next_num] - next_segment_headers.append(next_segment.get_fastg_header(next_positive)) - header += ','.join(next_segment_headers) - header += ';\n' - return header - def filter_by_read_depth(self, relative_depth_cutoff): """ This function removes segments from the graph based on a relative depth cutoff. Segments