Skip to content

Commit

Permalink
More unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
rrwick committed Dec 15, 2016
1 parent c048065 commit 876b7b7
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 32 deletions.
85 changes: 77 additions & 8 deletions test/test_assembly_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,10 @@ def test_segments(self):
self.assertEqual(len(self.graph.segments), 336)

def test_forward_links(self):
link_count = 0
for link_list in self.graph.forward_links.values():
link_count += len(link_list)
self.assertEqual(link_count, 904)
self.assertEqual(sum(len(x) for x in self.graph.forward_links.values()), 904)

def test_reverse_links(self):
link_count = 0
for link_list in self.graph.reverse_links.values():
link_count += len(link_list)
self.assertEqual(link_count, 904)
self.assertEqual(sum(len(x) for x in self.graph.reverse_links.values()), 904)

def test_links_match(self):
"""
Expand Down Expand Up @@ -126,6 +120,81 @@ def test_get_all_gfa_link_lines(self):
self.assertEqual(gfa_link_lines.count('\n'), 452)
self.assertEqual(gfa_link_lines.count('25M'), 452)

def test_filter_by_read_depth(self):
# A loop segment can be removed only when its depth drops below the threshold.
self.assertEqual(len(self.graph.segments), 336)
self.graph.filter_by_read_depth(0.5)
self.assertEqual(len(self.graph.segments), 336)
self.graph.segments[68].depth = 21.0
self.graph.filter_by_read_depth(0.5)
self.assertEqual(len(self.graph.segments), 336)
self.graph.segments[68].depth = 20.0
self.graph.filter_by_read_depth(0.5)
self.assertEqual(len(self.graph.segments), 335)

# A low-depth segment is only removed if its a dead end.
self.graph.segments[306].depth = 0.1
self.graph.filter_by_read_depth(0.5)
self.assertEqual(len(self.graph.segments), 335)
self.graph.remove_segments([273])
self.assertEqual(len(self.graph.segments), 334)
self.graph.filter_by_read_depth(0.5)
self.assertEqual(len(self.graph.segments), 333)

# def test_filter_homopolymer_loops(self):
# pass

def test_remove_segments(self):
self.assertEqual(len(self.graph.segments), 336)
self.assertEqual(sum(len(x) for x in self.graph.forward_links.values()), 904)
self.assertEqual(sum(len(x) for x in self.graph.reverse_links.values()), 904)
self.graph.remove_segments([276])
self.assertEqual(len(self.graph.segments), 335)
self.assertEqual(sum(len(x) for x in self.graph.forward_links.values()), 902)
self.assertEqual(sum(len(x) for x in self.graph.reverse_links.values()), 902)
self.graph.remove_segments([273])
self.assertEqual(len(self.graph.segments), 334)
self.assertEqual(sum(len(x) for x in self.graph.forward_links.values()), 894)
self.assertEqual(sum(len(x) for x in self.graph.reverse_links.values()), 894)
self.graph.remove_segments([67, 108, 222, 297])
self.assertEqual(len(self.graph.segments), 330)
self.assertEqual(sum(len(x) for x in self.graph.forward_links.values()), 870)
self.assertEqual(sum(len(x) for x in self.graph.reverse_links.values()), 870)

def test_remove_small_components(self):
self.assertEqual(len(self.graph.segments), 336)
self.assertEqual(sum(len(x) for x in self.graph.forward_links.values()), 904)
self.assertEqual(sum(len(x) for x in self.graph.reverse_links.values()), 904)
self.graph.remove_small_components(5000, 0)
self.assertEqual(len(self.graph.segments), 336)
self.assertEqual(sum(len(x) for x in self.graph.forward_links.values()), 904)
self.assertEqual(sum(len(x) for x in self.graph.reverse_links.values()), 904)
self.graph.remove_small_components(6000, 0)
self.assertEqual(len(self.graph.segments), 335)
self.assertEqual(sum(len(x) for x in self.graph.forward_links.values()), 902)
self.assertEqual(sum(len(x) for x in self.graph.reverse_links.values()), 902)
self.graph.remove_small_components(180000, 0)
self.assertEqual(len(self.graph.segments), 335)
self.assertEqual(sum(len(x) for x in self.graph.forward_links.values()), 902)
self.assertEqual(sum(len(x) for x in self.graph.reverse_links.values()), 902)
self.graph.remove_small_components(190000, 0)
self.assertEqual(len(self.graph.segments), 0)
self.assertEqual(sum(len(x) for x in self.graph.forward_links.values()), 0)
self.assertEqual(sum(len(x) for x in self.graph.reverse_links.values()), 0)

# def test_remove_small_dead_ends(self):
# pass

# def test_merge_all_possible(self):
# pass

# def test_merge_simple_path(self):
# pass

# def test_get_mean_path_depth(self):
# pass


# class TestAssemblyGraphFunctionsGfa(unittest.TestCase):
# """
# Tests various AssemblyGraph functions on a graph loaded from a GFA file.
Expand Down
24 changes: 0 additions & 24 deletions unicycler/assembly_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,30 +363,6 @@ def get_all_gfa_link_lines(self):
gfa_link_lines += self.gfa_link_line(start, end)
return gfa_link_lines

def get_fastg_header_with_links(self, segment, positive):
"""
Returns a full SPAdes-style FASTG header for a segment, including the leading '>', all of
the links, the trailing ';' and a newline.
"""
number = segment.number
if not positive:
number *= -1
header = '>' + segment.get_fastg_header(positive)
if number in self.forward_links:
header += ':'
next_segment_headers = []
for next_num in self.forward_links[number]:
if next_num < 0:
next_positive = False
next_num *= -1
else:
next_positive = True
next_segment = self.segments[next_num]
next_segment_headers.append(next_segment.get_fastg_header(next_positive))
header += ','.join(next_segment_headers)
header += ';\n'
return header

def filter_by_read_depth(self, relative_depth_cutoff):
"""
This function removes segments from the graph based on a relative depth cutoff. Segments
Expand Down

0 comments on commit 876b7b7

Please sign in to comment.