Skip to content

Commit

Permalink
Improve style.
Browse files: browse the repository at this point in the history.
  • Loading branch information
Adibvafa committed Aug 23, 2024
1 parent 7d288ed commit 6f3a500
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 56 deletions.
15 changes: 10 additions & 5 deletions CodonTransformer/CodonData.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,10 +345,12 @@ def get_amino_acid_sequence(
table=codon_table, # Codon table to use for translation
)
).strip()

return protein_seq \
if not return_correct_seq \
else (protein_seq, is_correct_seq(dna_seq, protein_seq, stop_symbol))

return (
protein_seq
if not return_correct_seq
else (protein_seq, is_correct_seq(dna_seq, protein_seq, stop_symbol))
)


def read_fasta_file(
Expand Down Expand Up @@ -399,7 +401,10 @@ def read_fasta_file(

# Translate DNA to protein sequence
protein, correct_seq = get_amino_acid_sequence(
dna, stop_symbol=STOP_SYMBOL, codon_table=codon_table, return_correct_seq=True
dna,
stop_symbol=STOP_SYMBOL,
codon_table=codon_table,
return_correct_seq=True,
)
description = str(record.description[: record.description.find("[")])
tokenized = get_merged_seq(protein, dna, seperator=STOP_SYMBOL)
Expand Down
2 changes: 0 additions & 2 deletions CodonTransformer/CodonEvaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@
from typing import List, Dict, Tuple
from tqdm import tqdm

from CodonTransformer.CodonUtils import AMINO2CODON_TYPE
from CodonTransformer.CodonData import build_amino2codon_skeleton, get_codon_frequencies


def get_CSI_weights(sequences: List[str]) -> Dict[str, float]:
Expand Down
25 changes: 5 additions & 20 deletions finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,22 +187,13 @@ def main(args):
help="Filename for the saved checkpoint",
)
parser.add_argument(
"--batch_size",
type=int,
default=6,
help="Batch size for training"
"--batch_size", type=int, default=6, help="Batch size for training"
)
parser.add_argument(
"--max_epochs",
type=int,
default=15,
help="Maximum number of epochs to train"
"--max_epochs", type=int, default=15, help="Maximum number of epochs to train"
)
parser.add_argument(
"--num_workers",
type=int,
default=5,
help="Number of workers for data loading"
"--num_workers", type=int, default=5, help="Number of workers for data loading"
)
parser.add_argument(
"--accumulate_grad_batches",
Expand All @@ -211,10 +202,7 @@ def main(args):
help="Number of batches to accumulate gradients",
)
parser.add_argument(
"--num_gpus",
type=int,
default=4,
help="Number of GPUs to use for training"
"--num_gpus", type=int, default=4, help="Number of GPUs to use for training"
)
parser.add_argument(
"--learning_rate",
Expand All @@ -235,10 +223,7 @@ def main(args):
help="Save checkpoint every N steps",
)
parser.add_argument(
"--seed",
type=int,
default=123,
help="Random seed for reproducibility"
"--seed", type=int, default=123, help="Random seed for reproducibility"
)
parser.add_argument("--debug", action="store_true", help="Enable debug mode")
args = parser.parse_args()
Expand Down
36 changes: 7 additions & 29 deletions pretrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,22 +199,13 @@ def main(args):
help="Directory where checkpoints will be saved",
)
parser.add_argument(
"--batch_size",
type=int,
default=6,
help="Batch size for training"
"--batch_size", type=int, default=6, help="Batch size for training"
)
parser.add_argument(
"--max_epochs",
type=int,
default=5,
help="Maximum number of epochs to train"
"--max_epochs", type=int, default=5, help="Maximum number of epochs to train"
)
parser.add_argument(
"--num_workers",
type=int,
default=5,
help="Number of workers for data loading"
"--num_workers", type=int, default=5, help="Number of workers for data loading"
)
parser.add_argument(
"--accumulate_grad_batches",
Expand All @@ -223,10 +214,7 @@ def main(args):
help="Number of batches to accumulate gradients",
)
parser.add_argument(
"--num_gpus",
type=int,
default=16,
help="Number of GPUs to use for training"
"--num_gpus", type=int, default=16, help="Number of GPUs to use for training"
)
parser.add_argument(
"--learning_rate",
Expand All @@ -241,21 +229,11 @@ def main(args):
help="Fraction of total steps to use for warmup",
)
parser.add_argument(
"--save_interval",
type=int,
default=5,
help="Save checkpoint every N epochs"
)
parser.add_argument(
"--seed",
type=int,
default=123,
help="Random seed for reproducibility"
"--save_interval", type=int, default=5, help="Save checkpoint every N epochs"
)
parser.add_argument(
"--debug",
action="store_true",
help="Enable debug mode"
"--seed", type=int, default=123, help="Random seed for reproducibility"
)
parser.add_argument("--debug", action="store_true", help="Enable debug mode")
args = parser.parse_args()
main(args)

0 comments on commit 6f3a500

Please sign in to comment.