diff --git a/snipit/command.py b/snipit/command.py index 7b834d6..c31eff0 100644 --- a/snipit/command.py +++ b/snipit/command.py @@ -73,7 +73,7 @@ def main(sysargs = sys.argv[1:]): else: args = parser.parse_args(sysargs) - num_seqs,ref_input,record_ids,length = sfunks.qc_alignment(args.alignment,args.reference,args.cds_mode,cwd) + num_seqs,ref_input,record_ids,length = sfunks.qc_alignment(args.alignment,args.reference,args.cds_mode,args.sequence_type,cwd) if args.reference: diff --git a/snipit/scripts/snp_functions.py b/snipit/scripts/snp_functions.py index 673ff3f..889241d 100644 --- a/snipit/scripts/snp_functions.py +++ b/snipit/scripts/snp_functions.py @@ -10,6 +10,10 @@ import math from itertools import groupby, count from collections import OrderedDict +from enum import Enum +import warnings +warnings.filterwarnings('ignore') + # imports from other modules from Bio import SeqIO @@ -19,6 +23,7 @@ import matplotlib.patches as patches from matplotlib.patches import Polygon + colour_list = ["lightgrey","white"] colour_cycle = cycle(colour_list) END_FORMATTING = '\033[0m' @@ -58,7 +63,7 @@ def check_ref(recombi_mode): sys.exit(-1) -def qc_alignment(alignment,reference,cds_mode,cwd): +def qc_alignment(alignment,reference,cds_mode,sequence_type,cwd): lengths = [] lengths_info = [] num_seqs = 0 @@ -102,6 +107,8 @@ def qc_alignment(alignment,reference,cds_mode,cwd): sys.stderr.write(red("Error: CDS mode flag used but alignment length not a multiple of 3.\n")) sys.exit(-1) + print(green(f"Note:") + f" assuming the alignment provided is of type {sequence_type}. If this is not the case, change input --sequence-type") + return num_seqs,ref_input,record_ids,lengths[0] def reference_qc(reference, record_ids,cwd): @@ -261,10 +268,13 @@ def find_snps(reference_seq,input_seqs,show_indels): return snp_dict,record_snps,len(var_counter) -def find_ambiguities(alignment,snp_dict): - +def find_ambiguities(alignment, snp_dict,sequence_type): + if sequence_type == "nt": + amb = NT_AMBIG + if sequence_type == "aa": + amb = AA_AMBIG snp_sites = collections.defaultdict(list)