[Finetuning] Add scripts for Genomics Long Range Benchmark Finetuning #46

Open · wants to merge 13 commits into main
54 changes: 54 additions & 0 deletions finetuning_glrb/README.md
@@ -0,0 +1,54 @@
# Fine-Tuning DNA Models on the Genomics Long Range Benchmark 🧬

This folder contains the necessary scripts and configurations to fine-tune DNA models on the [Genomics Long Range Benchmark](https://huggingface.co/datasets/InstaDeepAI/genomics-long-range-benchmark).

DNA models are loaded from the Hugging Face Hub 🤗.

## Getting Started

To fine-tune a model, execute the `finetune.sh` script. The script runs the `main.py` script with various command-line arguments that configure the fine-tuning process. Below is a description of each argument used.

**`--task`**: Choose one of the predefined benchmark tasks. Options include: `"variant_effect_causal_eqtl"`, `"variant_effect_pathogenic_clinvar"`, `"variant_effect_pathogenic_omim"`, `"cage_prediction"`, `"bulk_rna_expression"`, `"chromatin_features_histone_marks"`, `"chromatin_features_dna_accessibility"`, `"regulatory_element_promoter"`, `"regulatory_element_enhancer"`. A minimal loading sketch for these arguments follows the argument list below.

**`--seq_len`**: Specifies the sequence length in base pairs (bp).

**`--model_name`**: Name of the pre-trained model to fine-tune (on the HF hub).

**`--bp_per_token`**: Defines the number of base pairs per token used in the tokenization process of the model.
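
As a rough sanity check, the number of tokens the model actually processes is about `seq_len / bp_per_token`. A small illustration in Python, using the example values from `finetune.sh`:

```python
# Approximate number of tokens the model sees: seq_len / bp_per_token.
def n_tokens(seq_len: int, bp_per_token: int) -> int:
    return seq_len // bp_per_token

print(n_tokens(131_000, 1))  # e.g. Caduceus, 1 bp per token -> 131000 tokens
print(n_tokens(12_288, 6))   # e.g. NTv2, 6 bp per token (6-mers) -> 2048 tokens
```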

**`--save_dir`**: Directory where the checkpoints and logs will be saved.

**`--wandb_api_key`**: API key for Weights & Biases logging.

**`--name_wb`**: Name for the Weights & Biases run.

**`--train_batch_size`**: Defines the batch size for training.

**`--test_batch_size`**: Defines the batch size for testing/validation.

**`--num_workers`**: Number of workers to use for data loading.

**`--rcps`**: Indicates whether to use RCPS (reverse-complement parameter sharing, for RC-equivariant models such as Caduceus-PS) when extracting embeddings.

**`--num_epochs`**: Specifies the number of epochs to train.

**`--precision`**: Choose the precision. Options include: `"transformer-engine"`, `"transformer-engine-float16"`, `"16-true"`, `"16-mixed"`, `"bf16-true"`, `"bf16-mixed"`, `"32-true"`, `"64-true"`.

**`--accumulate_grad_batches`**: Number of batches over which to accumulate gradients before each optimizer step.

**`--learning_rate`**: Specifies the learning rate for the optimizer.

**`--log_interval`**: Interval (in steps) at which to log training metrics and run a validation step.
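
These trainer-related flags map onto standard trainer settings. Below is a minimal sketch, assuming the training loop is built on PyTorch Lightning (whose precision strings match the options listed above); the exact wiring in `main.py` may differ.

```python
import lightning as L

# Hypothetical wiring of the trainer-related flags; main.py may configure this differently.
trainer = L.Trainer(
    precision="16-mixed",          # --precision
    accumulate_grad_batches=128,   # --accumulate_grad_batches
    max_epochs=1,                  # --num_epochs
    val_check_interval=512,        # --log_interval (steps between validation runs)
)
```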

**`--train_ratio`**: Specifies the ratio of the dataset to use for training.

**`--eval_ratio`**: Specifies the ratio of the dataset to use for evaluation.
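
To illustrate how `--task`, `--seq_len`, and `--model_name` relate to the benchmark data and the pre-trained backbone, here is a minimal sketch. It is not a description of `main.py`; the `task_name`/`sequence_length` keyword names follow the benchmark's dataset card and are assumptions here.

```python
from datasets import load_dataset
from transformers import AutoModel, AutoTokenizer

task = "bulk_rna_expression"   # --task
seq_len = 12_000               # --seq_len
model_name = "kuleshov-group/caduceus-ps_seqlen-131k_d_model-256_n_layer-16"  # --model_name

# Benchmark split for the chosen task; the task_name/sequence_length keyword
# names are taken from the dataset card and assumed here.
dataset = load_dataset(
    "InstaDeepAI/genomics-long-range-benchmark",
    task_name=task,
    sequence_length=seq_len,
)

# Pre-trained backbone and tokenizer from the Hub (trust_remote_code is needed
# for custom architectures such as Caduceus).
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
```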



### Running the Script

To start finetuning, first make sure that you have modified the `finetune.sh` script with the correct parameters for your task. Then, simply run:

```bash
bash finetune.sh
```
Empty file added finetuning_glrb/__init__.py
Empty file.
35 changes: 35 additions & 0 deletions finetuning_glrb/finetune.sh
@@ -0,0 +1,35 @@
python finetuning_glrb/main.py \
--task "bulk_rna_expression" \
--seq_len 12000 \
--model_name "your_model_name_from_the_hub" \
--bp_per_token TBD \
--save_dir "output/" \
--wandb_api_key "your_wandb_api_key" \
--name_wb "your_wandb_run_name" \
--train_batch_size 4 \
--test_batch_size 4 \
--rcps true \
--num_workers 6 \
--num_epochs 1 \
--precision "16-mixed" \
--learning_rate "3e-5" \
--log_interval 512 \
--accumulate_grad_batches 128 \
--train_ratio 1.0 \
--eval_ratio 1.0

## Examples: per-model settings for the flags above

## Caduceus-PS
#task=bulk_rna_expression
#seq_len=131000
#bp_per_token=1
#model_name="kuleshov-group/caduceus-ps_seqlen-131k_d_model-256_n_layer-16"
#rcps=true

## NTv2
#task=regulatory_element_promoter
#seq_len=12288 # 2048 (seq len) * 6 (kmers)
#bp_per_token=6
#model_name="InstaDeepAI/nucleotide-transformer-v2-50m-multi-species"
#delete the --rcps flag (NTv2 is not an RC-equivariant model)