diff --git a/bin/extract_taxa_from_reads.py b/bin/extract_taxa_from_reads.py index a87f849..6d27270 100755 --- a/bin/extract_taxa_from_reads.py +++ b/bin/extract_taxa_from_reads.py @@ -34,10 +34,8 @@ def median(l): return l[i] -def load_from_taxonomy(taxonomy_dir): +def load_from_taxonomy(taxonomy_dir, parents, children): taxonomy = os.path.join(taxonomy_dir, "nodes.dmp") - parents = {} - children = defaultdict(set) try: with open(taxonomy, "r") as f: for line in f: @@ -63,9 +61,7 @@ def parse_depth(name): return depth -def infer_hierarchy(report_file): - parents = {} - children = defaultdict(set) +def infer_hierarchy(report_file, parents, children): hierarchy = [] with open(report_file, "r") as f: for line in f: @@ -97,7 +93,8 @@ def infer_hierarchy(report_file): if len(hierarchy) > 1: parent = hierarchy[-2] - parents[ncbi] = parent + if ncbi not in parents: + parents[ncbi] = parent children[parent].add(ncbi) return parents, children @@ -564,11 +561,11 @@ def main(): target_ranks = [] sys.stderr.write("Loading hierarchy\n") - parent, children = None, None + parent = {} + children = defaultdict(set) if args.taxonomy: - parent, children = load_from_taxonomy(args.taxonomy) - else: - parent, children = infer_hierarchy(args.report_file) + parent, children = load_from_taxonomy(args.taxonomy, parent, children) + parent, children = infer_hierarchy(args.report_file, parent, children) # get taxids to extract sys.stderr.write("Loading kreport\n")