diff --git a/CodonTransformer/CodonData.py b/CodonTransformer/CodonData.py index 3fc77bc..102d7b2 100644 --- a/CodonTransformer/CodonData.py +++ b/CodonTransformer/CodonData.py @@ -345,10 +345,12 @@ def get_amino_acid_sequence( table=codon_table, # Codon table to use for translation ) ).strip() - - return protein_seq \ - if not return_correct_seq \ - else (protein_seq, is_correct_seq(dna_seq, protein_seq, stop_symbol)) + + return ( + protein_seq + if not return_correct_seq + else (protein_seq, is_correct_seq(dna_seq, protein_seq, stop_symbol)) + ) def read_fasta_file( @@ -399,7 +401,10 @@ def read_fasta_file( # Translate DNA to protein sequence protein, correct_seq = get_amino_acid_sequence( - dna, stop_symbol=STOP_SYMBOL, codon_table=codon_table, return_correct_seq=True + dna, + stop_symbol=STOP_SYMBOL, + codon_table=codon_table, + return_correct_seq=True, ) description = str(record.description[: record.description.find("[")]) tokenized = get_merged_seq(protein, dna, seperator=STOP_SYMBOL) diff --git a/CodonTransformer/CodonEvaluation.py b/CodonTransformer/CodonEvaluation.py index a4db585..e94423a 100644 --- a/CodonTransformer/CodonEvaluation.py +++ b/CodonTransformer/CodonEvaluation.py @@ -11,8 +11,6 @@ from typing import List, Dict, Tuple from tqdm import tqdm -from CodonTransformer.CodonUtils import AMINO2CODON_TYPE -from CodonTransformer.CodonData import build_amino2codon_skeleton, get_codon_frequencies def get_CSI_weights(sequences: List[str]) -> Dict[str, float]: diff --git a/finetune.py b/finetune.py index d11def2..41b3a19 100644 --- a/finetune.py +++ b/finetune.py @@ -187,22 +187,13 @@ def main(args): help="Filename for the saved checkpoint", ) parser.add_argument( - "--batch_size", - type=int, - default=6, - help="Batch size for training" + "--batch_size", type=int, default=6, help="Batch size for training" ) parser.add_argument( - "--max_epochs", - type=int, - default=15, - help="Maximum number of epochs to train" + "--max_epochs", type=int, default=15, help="Maximum number of epochs to train" ) parser.add_argument( - "--num_workers", - type=int, - default=5, - help="Number of workers for data loading" + "--num_workers", type=int, default=5, help="Number of workers for data loading" ) parser.add_argument( "--accumulate_grad_batches", @@ -211,10 +202,7 @@ def main(args): help="Number of batches to accumulate gradients", ) parser.add_argument( - "--num_gpus", - type=int, - default=4, - help="Number of GPUs to use for training" + "--num_gpus", type=int, default=4, help="Number of GPUs to use for training" ) parser.add_argument( "--learning_rate", @@ -235,10 +223,7 @@ def main(args): help="Save checkpoint every N steps", ) parser.add_argument( - "--seed", - type=int, - default=123, - help="Random seed for reproducibility" + "--seed", type=int, default=123, help="Random seed for reproducibility" ) parser.add_argument("--debug", action="store_true", help="Enable debug mode") args = parser.parse_args() diff --git a/pretrain.py b/pretrain.py index b472872..bafd738 100644 --- a/pretrain.py +++ b/pretrain.py @@ -199,22 +199,13 @@ def main(args): help="Directory where checkpoints will be saved", ) parser.add_argument( - "--batch_size", - type=int, - default=6, - help="Batch size for training" + "--batch_size", type=int, default=6, help="Batch size for training" ) parser.add_argument( - "--max_epochs", - type=int, - default=5, - help="Maximum number of epochs to train" + "--max_epochs", type=int, default=5, help="Maximum number of epochs to train" ) parser.add_argument( - "--num_workers", - type=int, - default=5, - help="Number of workers for data loading" + "--num_workers", type=int, default=5, help="Number of workers for data loading" ) parser.add_argument( "--accumulate_grad_batches", @@ -223,10 +214,7 @@ def main(args): help="Number of batches to accumulate gradients", ) parser.add_argument( - "--num_gpus", - type=int, - default=16, - help="Number of GPUs to use for training" + "--num_gpus", type=int, default=16, help="Number of GPUs to use for training" ) parser.add_argument( "--learning_rate", @@ -241,21 +229,11 @@ def main(args): help="Fraction of total steps to use for warmup", ) parser.add_argument( - "--save_interval", - type=int, - default=5, - help="Save checkpoint every N epochs" - ) - parser.add_argument( - "--seed", - type=int, - default=123, - help="Random seed for reproducibility" + "--save_interval", type=int, default=5, help="Save checkpoint every N epochs" ) parser.add_argument( - "--debug", - action="store_true", - help="Enable debug mode" + "--seed", type=int, default=123, help="Random seed for reproducibility" ) + parser.add_argument("--debug", action="store_true", help="Enable debug mode") args = parser.parse_args() main(args)