
Merge pull request #222 from polio-nanopore/wt-dev
Large update to how initial-processing PAF files are parsed (implications for WT and NPEV calls)
aineniamh authored Aug 2, 2024
2 parents f8e386e + 41fc279 commit 5742323
Showing 20 changed files with 316 additions and 382 deletions.
2 changes: 1 addition & 1 deletion piranha/analysis/get_co_occurrence.py
@@ -19,7 +19,7 @@ def get_combinations(variants,read_fasta_file,reference,barcode,threshold):
sites = [int(i.split(":")[0]) for i in variant_list]

c = 100
for record in SeqIO.parse(read_fasta_file,"fasta"):
for record in SeqIO.parse(read_fasta_file,KEY_FASTA):

if not c:
break
4 changes: 2 additions & 2 deletions piranha/analysis/get_haplotypes.py
@@ -112,12 +112,12 @@ def write_haplotype_ref(ref,taxon,read_haplotypes,outdir):

def write_haplotype_fastq(reads,taxon,read_haplotypes,outdir):

reads = SeqIO.index(reads,"fastq")
reads = SeqIO.index(reads,KEY_FASTQ)
for h in read_haplotypes:
with open(os.path.join(outdir, f"{taxon}_{h}.fastq"),"w") as fw:
for read in read_haplotypes[h]:
record = reads[read]
SeqIO.write(record,fw,"fastq")
SeqIO.write(record,fw,KEY_FASTQ)


def get_haplotypes(fasta,vcf,reads,ref,out_haplotypes,outdir,taxon,min_reads,min_pcent):
4 changes: 2 additions & 2 deletions piranha/analysis/phylo_functions.py
@@ -190,7 +190,7 @@ def update_local_database(sample_sequences,detailed_csv,new_db_seqs,new_db_metad
with open(new_db_seqs,"w") as fw:
countnew = 0

for record in SeqIO.parse(sample_sequences, "fasta"):
for record in SeqIO.parse(sample_sequences, KEY_FASTA):
new_record = record
desc_list = new_record.description.split(" ")
write_record = True
@@ -211,7 +211,7 @@ def update_local_database(sample_sequences,detailed_csv,new_db_seqs,new_db_metad
new_desc_list = [i for i in desc_list if not i.startswith("barcode=")]
new_record.description = " ".join(new_desc_list)

SeqIO.write(new_record, fw, "fasta")
SeqIO.write(new_record, fw, KEY_FASTA)
countnew+=1
sample = record.id.split("|")[0]
record_ids[record.id] = sample
237 changes: 152 additions & 85 deletions piranha/analysis/preprocessing.py

Large diffs are not rendered by default.

56 changes: 55 additions & 1 deletion piranha/analysis/stool_functions.py
@@ -3,6 +3,8 @@
import collections
from Bio import SeqIO
import csv
import json
from itertools import groupby

from piranha.utils.config import *

@@ -102,4 +104,56 @@ def get_sample(barcodes_csv,barcode):
for row in reader:
if row[KEY_BARCODE] == barcode:
sample = row[KEY_SAMPLE]
return sample
return sample




def group_consecutive_sites(lst):
out = []
for _, g in groupby(enumerate(lst), lambda k: k[0] - k[1]):
start = next(g)[1]
end = list(v for _, v in g) or [start]
out.append(range(start, end[-1] + 1))
return out
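For illustration, the new helper collapses runs of consecutive sites into ranges (the input list below is invented):

group_consecutive_sites([2, 3, 4, 9, 11, 12])
# -> [range(2, 5), range(9, 10), range(11, 13)]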

def get_mask_dict(mask_file):
mask_dict = {}
with open(mask_file,"r") as f:
mask_json = json.load(f)

for ref in mask_json:

sites_to_mask = sorted(mask_json[ref])
mask_ranges = group_consecutive_sites(sites_to_mask)
mask_dict[ref] = mask_ranges
return mask_dict


def mask_low_coverage(mask_file, sequences,output):
mask_dict = get_mask_dict(mask_file)
records = 0
with open(output,"w") as fw:
for record in SeqIO.parse(sequences,"fasta"):
seq_id = record.id.split("|")[0]
if seq_id in mask_dict and mask_dict[seq_id]:
mask_ranges = mask_dict[seq_id]
n_count = 0
new_seq = str(record.seq)
for site in mask_ranges:

start,stop = site[0]-1,site[-1]+1

if site[0] == 0:
length = stop
n_count+=length
new_seq = ("N"*length) + new_seq[stop:]
else:
length = stop-start
new_seq = new_seq[:start] + ("N"*length) + new_seq[stop:]
n_count+=length

n_diff = new_seq.count("N") - str(record.seq).count("N")
fw.write(f">{record.description}\n{new_seq}\n")
else:
fw.write(f">{record.description}\n{record.seq}\n")
6 changes: 3 additions & 3 deletions piranha/analysis/variation_functions.py
@@ -37,7 +37,7 @@ def parse_variant_file(var_file):
def ref_dict_maker(ref_fasta):
ref_dict = {}
ref_fasta = pysam.FastaFile(ref_fasta)
for idx, base in enumerate(ref_fasta.fetch(ref_fasta.references[0])):
for idx, base in enumerate(ref_fasta.fetch(ref_fasta.references[0]), start=0):
ref_dict[idx] = base

return ref_dict
@@ -60,7 +60,7 @@ def non_ref_prcnt_calc(pos,pileup_dict,ref_dict):
else:
non_ref_prcnt = round((100 - ((ref_count / total) * 100)), 2)

return non_ref_prcnt
return non_ref_prcnt,total



@@ -138,7 +138,7 @@ def pileupper(bamfile,ref_dict,var_dict,base_q=13):
pileup_dict["T reads"] = T_counter
pileup_dict["G reads"] = G_counter
pileup_dict["- reads"] = del_counter
pileup_dict["Percentage"] = non_ref_prcnt_calc(pileupcolumn.pos,pileup_dict,ref_dict)
pileup_dict["Percentage"],pileup_dict["Total"] = non_ref_prcnt_calc(pileupcolumn.pos,pileup_dict,ref_dict)
pileup_dict["Ref base"] = ref_dict[pileupcolumn.pos]
variation_info.append(pileup_dict)

6 changes: 3 additions & 3 deletions piranha/command.py
@@ -45,10 +45,10 @@ def main(sysargs = sys.argv[1:]):

analysis_group = parser.add_argument_group('Analysis options')
analysis_group.add_argument("-s","--sample-type",action="store",help=f"Specify sample type. Options: `stool`, `environmental`. Default: `{VALUE_SAMPLE_TYPE}`")
analysis_group.add_argument("-m","--analysis-mode",action="store",help=f"Specify analysis mode to run. Options: `vp1`, `wg`. Default: `{VALUE_ANALYSIS_MODE}`")
analysis_group.add_argument("-m","--analysis-mode",action="store",help=f"Specify analysis mode to run, for preconfigured defaults. Options: `vp1`, `wg`. Default: `{VALUE_ANALYSIS_MODE}`")
analysis_group.add_argument("--medaka-model",action="store",help=f"Medaka model to run analysis using. Default: {VALUE_DEFAULT_MEDAKA_MODEL}")
analysis_group.add_argument("--medaka-list-models",action="store_true",help="List available medaka models and exit.")
analysis_group.add_argument("-q","--min-map-quality",action="store",type=int,help=f"Minimum mapping quality. Default: {VALUE_MIN_MAP_QUALITY}")
analysis_group.add_argument("-q","--min-map-quality",action="store",type=int,help=f"Minimum mapping quality. Range 0 to 60, however 0 can imply a multimapper. Default: {VALUE_MIN_MAP_QUALITY}")
analysis_group.add_argument("-n","--min-read-length",action="store",type=int,help=f"Minimum read length. Default: {READ_LENGTH_DICT[VALUE_ANALYSIS_MODE][0]}")
analysis_group.add_argument("-x","--max-read-length",action="store",type=int,help=f"Maximum read length. Default: {READ_LENGTH_DICT[VALUE_ANALYSIS_MODE][1]}")
analysis_group.add_argument("-d","--min-read-depth",action="store",type=int,help=f"Minimum read depth required for consensus generation. Default: {VALUE_MIN_READS}")
@@ -128,7 +128,7 @@ def main(sysargs = sys.argv[1:]):
misc.add_check_valid_arg(KEY_ORIENTATION,args.orientation,VALID_ORIENTATION,config)

# grabs the snakefile
snakefile = data_install_checks.get_snakefile(thisdir,config[KEY_ANALYSIS_MODE])
snakefile = data_install_checks.get_snakefile(thisdir,"main")

# Checks medaka options if non default values used.
analysis_arg_parsing.medaka_options_parsing(args.medaka_model,args.medaka_list_models,config)
6 changes: 3 additions & 3 deletions piranha/data/report.mako
@@ -889,7 +889,7 @@
],
"layer":[
{
"transform": [{"calculate": "datum[filterBy]", "as": "EV reads present"}],
"transform": [{"calculate": "datum[filterBy]", "as": "Reads present"}],
"mark": {"type":"circle","size":500},
"encoding": {
"x": {"field": "x", "type": "nominal",
@@ -905,7 +905,7 @@
"labelFont":"Helvetica Neue",
"labelFontSize":18
}},
"fill": {"field": "EV reads present",
"fill": {"field": "Reads present",
"scale": {"range": ["#e68781", "#48818d", "#b2b2b2"]},
"sort": ["Present","Absent","N/A"]
},
@@ -1049,7 +1049,7 @@
<div id="citation_box" class="info_box">
<p>${LANGUAGE_CONFIG["53"]}</p>
<p>
<strong>O’Toole Á, Colquhoun R, Ansley C, Troman C, Maloney D, Vance Z, Akello J, Bujaki E, Majumdar M, Khurshid A, Arshad Y, Alam MM, Martin J, Shaw A, Grassly N, Rambaut A</strong> (2023) Automated detection and classification of polioviruses from nanopore sequencing reads using piranha. <i>bioRxiv</i> <a style='color:#e68781' href="https://doi.org/10.1101/2023.09.05.556319">https://doi.org/10.1101/2023.09.05.556319</a>
<strong>O’Toole Á, Colquhoun R, Ansley C, Troman C, Maloney D, Vance Z, Akello J, Bujaki E, Majumdar M, Khurshid A, Arshad Y, Alam MM, Martin J, Shaw A, Grassly N, Rambaut A</strong> (2023) Automated detection and classification of polioviruses from nanopore sequencing reads using piranha. <i>Virus Evolution</i> <a style='color:#e68781' href="https://doi.org/10.1093/ve/veae023">https://doi.org/10.1093/ve/veae023</a>
</p>
</div>
<footer class="page-footer">
12 changes: 7 additions & 5 deletions piranha/input_parsing/input_qc.py
@@ -102,7 +102,7 @@ def parse_barcodes_csv(barcodes_csv,config):
# total = 0
# seq_ids = set()
# try:
# for record in SeqIO.parse(supplementary_sequences,"fasta"):
# for record in SeqIO.parse(supplementary_sequences,KEY_FASTA):
# seq_ids.add(record.id)

# total +=1
@@ -142,15 +142,17 @@ def qc_supplementary_metadata_file(supplementary_metadata,config):


def parse_fasta_file(supplementary_datadir,supp_file,seq_records,no_reference_group,total_seqs,seq_info,config):
for record in SeqIO.parse(os.path.join(supplementary_datadir,supp_file),"fasta"):
for record in SeqIO.parse(os.path.join(supplementary_datadir,supp_file),KEY_FASTA):
total_seqs["total"] +=1
ref_group = ""
for field in record.description.split(" "):
if field.startswith(VALUE_REFERENCE_GROUP_FIELD):
ref_group = field.split("=")[1]
# print(VALUE_REFERENCE_GROUP_FIELD, "ref group", ref_group)

if ref_group not in config[KEY_REFERENCES_FOR_CNS]:
no_reference_group.add(record.id)
# print(record.id)
else:
total_seqs[ref_group]+=1
seq_records.append(record)
@@ -202,7 +204,7 @@ def gather_supplementary_data(supplementary_datadir,supplementary_sequences,supp

check_there_are_seqs(total_seqs,supplementary_datadir,no_reference_group,config)

SeqIO.write(seq_records,fw, "fasta")
SeqIO.write(seq_records,fw, KEY_FASTA)

supplementary_metadata_header = set()

@@ -362,7 +364,7 @@ def parse_input_group(barcodes_csv,readdir,reference_sequences,reference_group_f
seq_ids = collections.Counter()
ref_group_field_in_headers = True
ref_group_values = set()
for record in SeqIO.parse(config[KEY_REFERENCE_SEQUENCES],"fasta"):
for record in SeqIO.parse(config[KEY_REFERENCE_SEQUENCES],KEY_FASTA):
ref_group_value = parse_ref_group_values(record.description,config[KEY_REFERENCE_GROUP_FIELD])
ref_group_values.add(ref_group_value)

@@ -456,7 +458,7 @@ def control_group_parsing(positive_control, negative_control, positive_reference
print(f"- {neg}")

if config[KEY_INCLUDE_POSITIVE_REFERENCES]:
refs = SeqIO.index(config[KEY_REFERENCE_SEQUENCES],"fasta")
refs = SeqIO.index(config[KEY_REFERENCE_SEQUENCES],KEY_FASTA)
not_in = set()
for ref in config[KEY_POSITIVE_REFERENCES]:
if ref not in refs:
9 changes: 6 additions & 3 deletions piranha/scripts/piranha_consensus.smk
@@ -104,6 +104,7 @@ rule gather_merge_cns:
yaml = os.path.join(config[KEY_TEMPDIR],"consensus_config.yaml")
run:
sequences = collections.defaultdict(set)
variant_calls = collections.defaultdict(set)
ref_seqs = {}

for cns_file in input.cns:
@@ -113,19 +114,21 @@
haplodir = "/".join(cns_file.split("/")[:-1])+"_cns"

haplo_bam = os.path.join(haplodir, "calls_to_ref.bam")
for record in SeqIO.parse(cns_file, "fasta"):
haplo_vcf = os.path.join(haplodir, "medaka.vcf")
for record in SeqIO.parse(cns_file, KEY_FASTA):
print(record.id)
sequences[str(record.seq)].add(haplo_bam)
variant_calls[str(record.seq)].add(haplo_vcf)
ref_seqs[str(record.seq)] = record.id

ref_file_dict = {}
for ref_file in input.ref:
for record in SeqIO.parse(ref_file,"fasta"):
for record in SeqIO.parse(ref_file,KEY_FASTA):
ref_file_dict[record.id] = ref_file
# print(ref_file_dict)
cns_file_dict = {}
for cns_file in input.cns_cns:
for record in SeqIO.parse(cns_file,"fasta"):
for record in SeqIO.parse(cns_file,KEY_FASTA):
cns_file_dict[record.id] = cns_file
# print(cns_file_dict)

5 changes: 3 additions & 2 deletions piranha/scripts/piranha_curate.smk
@@ -22,15 +22,16 @@ rule all:
os.path.join(config[KEY_TEMPDIR],"variants.csv"),
os.path.join(config[KEY_TEMPDIR],"masked_variants.csv"),
expand(os.path.join(config[KEY_TEMPDIR],"snipit","{reference}.svg"), reference=REFERENCES)
# expand(os.path.join(config[KEY_TEMPDIR],"reference_analysis","{reference}.merged_cns.mask.tsv"), reference=REFERENCES)


# do this per cns

rule files:
params:
ref= os.path.join(config[KEY_TEMPDIR],"reference_analysis","{reference}.ref.fasta"),
cns = os.path.join(config[KEY_TEMPDIR],"reference_analysis","{reference}.merged_cns.fasta")

cns = os.path.join(config[KEY_TEMPDIR],"reference_analysis","{reference}.merged_cns.fasta"),
vcf = os.path.join(config[KEY_TEMPDIR],"variant_calls","{reference}.vcf")

rule join_cns_ref:
input:
6 changes: 3 additions & 3 deletions piranha/scripts/piranha_haplotype.smk
@@ -35,7 +35,7 @@ rule rasusa:
fastq= os.path.join(config[KEY_TEMPDIR],"reference_analysis","{reference}","haplotyping","downsample.fastq")
run:
ref_len = 0
for record in SeqIO.parse(input.ref,"fasta"):
for record in SeqIO.parse(input.ref,KEY_FASTA):
ref_len = len(record)
shell("rasusa -i {input.reads:q} -c {params.depth:q} " + f"-g {ref_len}b" + " -o {output.fastq:q}")

@@ -127,7 +127,7 @@ rule haplotype_qc:
minimum reads for the haplotype to be used
"""
partitions = parse_partition_file(input.partition)
seq_index = SeqIO.index(input.reads, "fastq")
seq_index = SeqIO.index(input.reads, KEY_FASTQ)
merge_info = collapse_close(input.flopp,config[KEY_MIN_HAPLOTYPE_DISTANCE],input.vcf)
with open(output.txt,"w") as fhaplo:
merged_haplo_count = 0
Expand All @@ -151,7 +151,7 @@ rule haplotype_qc:
record = seq_index[read]
records.append(record)

SeqIO.write(records,fw,"fastq")
SeqIO.write(records,fw,KEY_FASTQ)

haplo_ref = os.path.join(params.haplodir,f"{haplotype}.reference.fasta")
shell(f"cp {input.ref} {haplo_ref}")
@@ -120,16 +120,18 @@ rule curate_sequences:
rule generate_variation_info:
input:
snakefile = os.path.join(workflow.current_basedir,"piranha_variation.smk"),
fasta = rules.curate_sequences.output.fasta,
yaml = rules.generate_consensus_sequences.output.yaml
yaml = rules.generate_consensus_sequences.output.yaml,
variants = rules.curate_sequences.output.csv,
fasta = os.path.join(config[KEY_TEMPDIR],"{barcode}","consensus_sequences.fasta")
params:
barcode = "{barcode}",
outdir = os.path.join(config[KEY_OUTDIR],"{barcode}"),
tempdir = os.path.join(config[KEY_TEMPDIR],"{barcode}")
threads: workflow.cores
log: os.path.join(config[KEY_TEMPDIR],"logs","{barcode}_variation.smk.log")
output:
json = os.path.join(config[KEY_TEMPDIR],"{barcode}","variation_info.json")
json = os.path.join(config[KEY_TEMPDIR],"{barcode}","variation_info.json"),
json_mask = os.path.join(config[KEY_TEMPDIR],"{barcode}","mask_info.json")
run:
# decide if we want 1 per haplotype or 1 per ref group, will need mods either way
sample = get_sample(config[KEY_BARCODES_CSV],params.barcode)
@@ -143,10 +145,20 @@ rule generate_variation_info:
f"sample='{sample}' "
"--cores {threads} &> {log:q}")


rule mask_consensus_sequences:
input:
mask_json = rules.generate_variation_info.output.json_mask,
fasta = rules.curate_sequences.output.fasta
output:
fasta = os.path.join(config[KEY_TEMPDIR],"{barcode}","consensus_sequences.masked.fasta")
run:
mask_low_coverage(input.mask_json, input.fasta,output.fasta)

rule gather_consensus_sequences:
input:
composition = rules.files.params.composition,
fasta = expand(rules.curate_sequences.output.fasta, barcode=config[KEY_BARCODES])
fasta = expand(rules.mask_consensus_sequences.output.fasta, barcode=config[KEY_BARCODES])
params:
publish_dir = os.path.join(config[KEY_OUTDIR],"published_data")
output:
@@ -157,8 +169,6 @@ rule gather_consensus_sequences:
# also header now needs hap parsing & hap->CNS mapping
gather_fasta_files(input.composition, config[KEY_BARCODES_CSV], input.fasta,config[KEY_ALL_METADATA],config[KEY_RUNNAME], output[0],params.publish_dir,config)



rule generate_report:
input:
consensus_seqs = rules.gather_consensus_sequences.output.fasta,
@@ -177,6 +187,8 @@ rule generate_report:
config_loaded = yaml.safe_load(f)
with open(input.cns_yaml, 'r') as f:
cns_config_loaded = yaml.safe_load(f)

# var dict now carries the total, so we can infer from it which sites have been masked out
make_sample_report(output.html,
input.variation_info,
input.consensus_seqs,