Commit

now reporting actual hit and the hit that gets counted as top hit per ddns group
aineniamh committed Feb 28, 2024
1 parent 03a662f commit 628b907
Showing 1 changed file with 61 additions and 58 deletions.
119 changes: 61 additions & 58 deletions piranha/analysis/preprocessing.py
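In brief, this commit swaps hard-coded column-name strings for shared KEY_ constants throughout preprocessing.py, and the per-read hits output now records both the reference a read actually mapped to and the reference it gets counted under (the top hit for its ddns group). A minimal sketch of the new row shape follows; it is not taken from the repository, and the KEY_ constant values, filename and example values are assumptions for illustration only:

import csv

# Assumed string values for the KEY_ constants used in write_out_hits (not part of this diff)
KEY_READ_NAME, KEY_HIT, KEY_REFERENCE_HIT, KEY_REFERENCE_GROUP = "read_name", "hit", "reference_hit", "reference_group"
KEY_START, KEY_END, KEY_ALN_BLOCK_LEN = "start", "end", "aln_block_len"

fieldnames = [KEY_READ_NAME, KEY_HIT, KEY_REFERENCE_HIT, KEY_REFERENCE_GROUP,
              KEY_START, KEY_END, KEY_ALN_BLOCK_LEN]

with open("example_hits.csv", "w") as fw:
    writer = csv.DictWriter(fw, lineterminator="\n", fieldnames=fieldnames)
    writer.writeheader()
    # hit           = the reference this read is counted against (its group's top hit)
    # reference_hit = the reference the read actually mapped to in the PAF file
    writer.writerow({KEY_READ_NAME: "read_0001",
                     KEY_HIT: "group_representative_ref",
                     KEY_REFERENCE_HIT: "actual_best_hit_ref",
                     KEY_REFERENCE_GROUP: "reference_group_1",
                     KEY_START: 30, KEY_END: 1450, KEY_ALN_BLOCK_LEN: 1420})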
@@ -23,23 +23,23 @@ def gather_filter_reads_by_length(dir_in,barcode,reads_out,config):
for reads_in in f:
if reads_in.endswith(".gz") or reads_in.endswith(".gzip"):
with gzip.open(os.path.join(dir_in,reads_in), "rt") as handle:
- for record in SeqIO.parse(handle, "fastq"):
+ for record in SeqIO.parse(handle, KEY_FASTQ):
total_reads +=1
length = len(record)
if length > int(config[KEY_MIN_READ_LENGTH]) and length < int(config[KEY_MAX_READ_LENGTH]):
fastq_records.append(record)


elif reads_in.endswith(".fastq") or reads_in.endswith(".fq"):
- for record in SeqIO.parse(os.path.join(dir_in,reads_in),"fastq"):
+ for record in SeqIO.parse(os.path.join(dir_in,reads_in),KEY_FASTQ):
total_reads +=1
length = len(record)
if length > int(config[KEY_MIN_READ_LENGTH]) and length < int(config[KEY_MAX_READ_LENGTH]):
fastq_records.append(record)

print(green(f"Total reads {barcode}:"),total_reads)
print(green(f"Total passed reads {barcode}:"),len(fastq_records))
- SeqIO.write(fastq_records,fw, "fastq")
+ SeqIO.write(fastq_records,fw, KEY_FASTQ)

def parse_match_field(description,reference_match_field):
for item in str(description).split(" "):
@@ -89,43 +89,43 @@ def make_match_field_to_reference_group_map(ref_map):
def parse_line(line):
values = {}
tokens = line.rstrip("\n").split("\t")
- values["read_name"], values["read_len"] = tokens[:2]
+ values[KEY_READ_NAME], values[KEY_READ_LEN] = tokens[:2]

- values["read_hit_start"] = int(tokens[2])
- values["read_hit_end"] = int(tokens[3])
- values["direction"] = tokens[4]
- values["ref_hit"], values["ref_len"], values["coord_start"], values["coord_end"], values["matches"], values["aln_block_len"],values["map_quality"] = tokens[5:12]
+ values[KEY_READ_HIT_START] = int(tokens[2])
+ values[KEY_READ_HIT_END] = int(tokens[3])
+ values[KEY_DIRECTION] = tokens[4]
+ values[KEY_REFERENCE_HIT], values[KEY_REFERENCE_LEN], values[KEY_COORD_START], values[KEY_COORD_END], values[KEY_MATCHES], values[KEY_ALN_BLOCK_LEN],values[KEY_MAP_QUALITY] = tokens[5:12]

- values["ref_len"] = int(values["ref_len"])
- values["aln_block_len"] = int(values["aln_block_len"])
+ values[KEY_REFERENCE_LEN] = int(values[KEY_REFERENCE_LEN])
+ values[KEY_ALN_BLOCK_LEN] = int(values[KEY_ALN_BLOCK_LEN])

return values

def add_to_hit_dict(hits,mapping,min_map_len,min_map_quality,unmapped):
status,description = "",""
- if mapping["direction"] == "+":
- start = mapping["read_hit_start"]
- end = mapping["read_hit_end"]
- elif mapping["direction"] == "-":
- start = mapping["read_hit_end"]
- end = mapping["read_hit_start"]
+ if mapping[KEY_DIRECTION] == "+":
+ start = mapping[KEY_READ_HIT_START]
+ end = mapping[KEY_READ_HIT_END]
+ elif mapping[KEY_DIRECTION] == "-":
+ start = mapping[KEY_READ_HIT_END]
+ end = mapping[KEY_READ_HIT_START]
else:
unmapped+=1
- status = "unmapped"
+ status = KEY_UNMAPPED

if not status:
- if int(mapping["aln_block_len"]) > min_map_len:
- if int(mapping["map_quality"]) > min_map_quality:
- hits[mapping["ref_hit"]].add((mapping["read_name"],start,end,mapping["aln_block_len"]))
- status = "mapped"
- description = f"MAPQ:{mapping['map_quality']} ALN_LEN:{mapping['aln_block_len']}"
+ if int(mapping[KEY_ALN_BLOCK_LEN]) > min_map_len:
+ if int(mapping[KEY_MAP_QUALITY]) > min_map_quality:
+ hits[mapping[KEY_REFERENCE_HIT]].add((mapping[KEY_READ_NAME],mapping[KEY_REFERENCE_HIT],start,end,mapping[KEY_ALN_BLOCK_LEN]))
+ status = KEY_MAPPED
+ description = f"MAPQ:{mapping[KEY_MAP_QUALITY]} ALN_LEN:{mapping[KEY_ALN_BLOCK_LEN]}"
else:
unmapped+=1
- status = "filtered"
+ status = KEY_FILTERED
description = "mapping quality too low"
else:
unmapped+=1
- status = "filtered"
+ status = KEY_FILTERED
description = "alignment block too short"

return unmapped,status,description
@@ -150,7 +150,7 @@ def group_hits(paf_file,
same_read_row = []
last_mapping = None
with open(mapping_filter_file,"w") as fw:
- writer = csv.DictWriter(fw, fieldnames=["read_name","status","description"],lineterminator='\n')
+ writer = csv.DictWriter(fw, fieldnames=[KEY_READ_NAME,KEY_STATUS,KEY_DESCRIPTION],lineterminator='\n')
writer.writeheader()
with open(paf_file, "r") as f:
for l in f:
@@ -159,12 +159,12 @@

if not current_readname:
#so the very first read
- current_readname = mapping["read_name"]
+ current_readname = mapping[KEY_READ_NAME]
same_read_row.append(mapping)
continue

#second read until the end
- if mapping["read_name"] == current_readname:
+ if mapping[KEY_READ_NAME] == current_readname:
same_read_row.append(mapping)
else:
total_reads +=1
@@ -173,28 +173,28 @@

if len(same_read_row) >1:
#chimeric/multimapped reads
- h = [i["ref_hit"] for i in same_read_row]
+ h = [i[KEY_REFERENCE_HIT] for i in same_read_row]
h = "|".join(sorted(h))
multi_hits[h]+=1
ambiguous +=1
description += f"; ambiguous reference hit: {h}"

- row = {"read_name":current_readname,
- "status":status,
- "description":description}
+ row = {KEY_READ_NAME:current_readname,
+ KEY_STATUS:status,
+ KEY_DESCRIPTION:description}
writer.writerow(row)

- current_readname = mapping["read_name"]
+ current_readname = mapping[KEY_READ_NAME]
same_read_row = [mapping]

total_reads +=1

first_hit = same_read_row[0]
unmapped,status,description = add_to_hit_dict(hits,first_hit,min_aln_block,min_map_quality,unmapped)

- row = {"read_name":current_readname,
- "status":status,
- "description":description}
+ row = {KEY_READ_NAME:current_readname,
+ KEY_STATUS:status,
+ KEY_DESCRIPTION:description}
writer.writerow(row)

ref_group_hits = collections.defaultdict(set)
@@ -236,10 +236,10 @@ def write_out_report(hits,ref_group_ref,csv_out,unmapped,total_reads,barcode):

unmapped_row = {
KEY_BARCODE:barcode,
- KEY_REFERENCE:"unmapped",
+ KEY_REFERENCE:KEY_UNMAPPED,
KEY_NUM_READS:unmapped,
KEY_PERCENT: pcent_unmapped,
- KEY_REFERENCE_GROUP:"unmapped"}
+ KEY_REFERENCE_GROUP:KEY_UNMAPPED}
writer.writerow(unmapped_row)

for ref_group in hits:
@@ -253,29 +253,32 @@ def write_out_report(hits,ref_group_ref,csv_out,unmapped,total_reads,barcode):
KEY_REFERENCE_GROUP:ref_group}
writer.writerow(mapped_row)

- def write_out_hits(hits,ref_group_ref,outfile):
+
+ def write_out_hits(ref_group_hits,ref_group_ref,outfile):
with open(outfile,"w") as fw:
- writer = csv.DictWriter(fw, lineterminator="\n",fieldnames=["read_name","hit","start","end","aln_block_len"])
+ writer = csv.DictWriter(fw, lineterminator="\n",fieldnames=[KEY_READ_NAME,KEY_HIT,KEY_REFERENCE_HIT,KEY_REFERENCE_GROUP,KEY_START,KEY_END,KEY_ALN_BLOCK_LEN])
writer.writeheader()
- for ref_group in hits:
- hit_info = hits[ref_group]
+ for ref_group in ref_group_hits:
+ hit_info = ref_group_hits[ref_group]
reference = ref_group_ref[ref_group]
for read in hit_info:
- name,start,end,aln_len = read
- row = {"read_name":name,
- "hit":reference,
- "start":start,
- "end":end,
- "aln_block_len":aln_len}
+ name,ref_hit,start,end,aln_len = read
+ row = {KEY_READ_NAME:name,
+ KEY_HIT:reference,
+ KEY_REFERENCE_HIT:ref_hit,
+ KEY_REFERENCE_GROUP:ref_group,
+ KEY_START:start,
+ KEY_END:end,
+ KEY_ALN_BLOCK_LEN:aln_len}
writer.writerow(row)

def write_out_multi_mapped(multi_out,multi_hits):

with open(multi_out,"w") as fw:
- writer = csv.DictWriter(fw, fieldnames=["references","multihit_count"],lineterminator="\n")
+ writer = csv.DictWriter(fw, fieldnames=[KEY_REFERENCES,KEY_MULTIHIT_COUNT],lineterminator="\n")
writer.writeheader()
for i in multi_hits:
- row= {"references":i,"multihit_count":multi_hits[i]}
+ row= {KEY_REFERENCES:i,KEY_MULTIHIT_COUNT:multi_hits[i]}
writer.writerow(row)


@@ -332,15 +332,15 @@ def parse_paf_file(paf_file,
writer.writeheader()

with open(multi_out,"w") as fw:
- writer = csv.DictWriter(fw, fieldnames=["references","multihit_count"],lineterminator="\n")
+ writer = csv.DictWriter(fw, fieldnames=[KEY_REFERENCES,KEY_MULTIHIT_COUNT],lineterminator="\n")
writer.writeheader()

with open(hits_out,"w") as fw:
- writer = csv.DictWriter(fw, lineterminator="\n",fieldnames=["read_name","hit","start","end","aln_block_len"])
+ writer = csv.DictWriter(fw, lineterminator="\n",fieldnames=[KEY_READ_NAME,KEY_HIT,KEY_START,KEY_END,KEY_ALN_BLOCK_LEN])
writer.writeheader()

with open(mapping_filter_out,"w") as fw:
- writer = csv.DictWriter(fw, lineterminator="\n",fieldnames=["read_name","status","description"])
+ writer = csv.DictWriter(fw, lineterminator="\n",fieldnames=[KEY_READ_NAME,KEY_STATUS,KEY_DESCRIPTION])
writer.writeheader()


@@ -394,7 +394,7 @@ def diversity_report(input_files,csv_out,summary_out,ref_file,config):
for barcode in refs_out:
refs = refs_out[barcode]

- refs = [i for i in refs if i != "unmapped"]
+ refs = [i for i in refs if i != KEY_UNMAPPED]
config[barcode]=refs
if refs:
config[KEY_BARCODES].append(barcode)
@@ -407,13 +407,13 @@ def check_which_refs_to_write(input_csv,min_reads,min_pcent):
reader = csv.DictReader(f)
for row in reader:
if int(row[KEY_NUM_READS]) >= min_reads and float(row[KEY_PERCENT]) >= min_pcent:
- if row[KEY_REFERENCE] != "unmapped":
+ if row[KEY_REFERENCE] != KEY_UNMAPPED:
to_write.add(row[KEY_REFERENCE])
print(f"{row[KEY_REFERENCE_GROUP]}\t{row[KEY_NUM_READS]} reads\t{row[KEY_PERCENT]}% of sample")
return list(to_write)

def write_out_fastqs(input_csv,input_hits,input_fastq,outdir,primer_length,config):
- seq_index = SeqIO.index(input_fastq, "fastq")
+ seq_index = SeqIO.index(input_fastq, KEY_FASTQ)

to_write = check_which_refs_to_write(input_csv,config[KEY_MIN_READS],config[KEY_MIN_PCENT])
handle_dict = {}
@@ -426,15 +426,15 @@ def write_out_fastqs(input_csv,input_hits,input_fastq,outdir,primer_length,config):
reader = csv.DictReader(f)
for row in reader:
try:
- read_name = row["read_name"]
+ read_name = row[KEY_READ_NAME]
record = seq_index[read_name]

- hit = row["hit"]
+ hit = row[KEY_HIT]
handle = handle_dict[hit]

trimmed_record = record[primer_length:-primer_length]

- SeqIO.write(trimmed_record,handle,"fastq")
+ SeqIO.write(trimmed_record,handle,KEY_FASTQ)
except:
not_written[hit]+=1
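One consequence worth flagging (an observation about the diff above, not text from the repository): the per-read tuple stored by add_to_hit_dict gains a field, so the unpack in write_out_hits has to change in the same commit. A toy illustration with made-up values:

# New 5-field tuple stored per read in the hits dictionary
hit_tuple = ("read_0001", "actual_best_hit_ref", 30, 1450, 1420)

# Old unpack (4 names) would now raise ValueError: too many values to unpack
# name, start, end, aln_len = hit_tuple

# New unpack, matching the updated write_out_hits
name, ref_hit, start, end, aln_len = hit_tuple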
