From b3cba48fa0bc8b2fc05a9050c48611a0ab1f766c Mon Sep 17 00:00:00 2001 From: cody-mar10 Date: Wed, 9 Oct 2024 14:00:41 -0500 Subject: [PATCH] initial commit --- examples/finetuning.ipynb | 517 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 517 insertions(+) create mode 100644 examples/finetuning.ipynb diff --git a/examples/finetuning.ipynb b/examples/finetuning.ipynb new file mode 100644 index 0000000..f00ac32 --- /dev/null +++ b/examples/finetuning.ipynb @@ -0,0 +1,517 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Note: this notebook will only work for pst versions >1.1.0**" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Finetuning with a new genome-level objective\n", + "The vPST was pretrained with a triplet loss objective, evaluating the genome embeddings.\n", + "\n", + "If you want to apply the vPST to a new objective (transfer learning), then you need to subclass the `BaseProteinSetTransformer` module class and update the following methods:\n", + "\n", + "1. `forward` code needed to handle a minibatch and compute the loss\n", + "2. `setup_objective` code needed to create a callable that computes the loss directly. This code is called upon initialization of the model, and the `forward` method calls the `.criterion` callable that is returned by this method.\n", + "\n", + "Additionally, if the loss function maintains state (such as the margin and scaling values of a triplet loss objective), then you can create a subclass of the `BaseModelConfig` with the loss field using a custom subclass of the `BaseLossConfig` that specifies the name and default values of stateful parameters needed by the loss function. This is only necessary for tunable hyperparameters of the loss function, NOT just any arguments needed to setup the loss function callable.\n", + "\n", + "-----\n", + "\n", + "Let's look at an example where we want to predict some random binary feature about the genomes in the sample dataset provided. For demonstration purposes, we will suppose that we have some tunable weight required for the loss function." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Seed set to 111\n" + ] + }, + { + "data": { + "text/plain": [ + "111" + ] + }, + "execution_count": 1, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from pst import BaseProteinSetTransformer as BasePST\n", + "from pst import GenomeDataModule, BaseLossConfig, BaseModelConfig, GenomeGraphBatch\n", + "\n", + "import lightning as L\n", + "import torch\n", + "from pydantic import Field\n", + "\n", + "L.seed_everything(111)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Since we are changing the objective of our new model that is derived from a pretrained PST, we need to define:\n", + "\n", + "1. A custom loss config model that subclasses `BaseLossConfig` IF the loss function requires a tunable state\n", + "2. A custom model config model is a a subclass of `BaseModelConfig` if any subfields need to be changed. The fields of this config model are available to the class through the `self.config` attribute.\n", + "3. A custom loss `torch.nn.Module` or function that computes the loss given the outputs of the model's forward pass and any expected targets, if any" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "class CustomLossConfig(BaseLossConfig):\n", + " tunable_weight: float = Field(0.5, ge=0.0, le=1.0, description=\"some tunable weight\")\n", + "\n", + "class CustomModelConfig(BaseModelConfig):\n", + " loss: CustomLossConfig\n", + "\n", + "class CustomLossFn(torch.nn.Module):\n", + " def __init__(self, weight: float):\n", + " super().__init__()\n", + " self.weight = weight # just an example, idk why you would use this\n", + " self.fn = torch.nn.BCEWithLogitsLoss()\n", + "\n", + " def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:\n", + " loss = self.fn(y_pred, y_true) * self.weight\n", + " return loss" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now you can create a subclass of the `BaseProteinSetTransformer` model that redefines 3 methods:\n", + "\n", + "1. In the `__init__` method, add any new layers and other attributes that are required by the model. Generally, there should only be 1 argument to the `__init__` method called \"config\". Thus, any additional attributes should be added as fields to a custom-defined config model.\n", + "2. `setup_objective` returns a callable that is used to compute the loss during a forward pass. It receives all values of the `self.config.loss` model as keyword arguments. The return value of this function is stored in `self.criterion`.\n", + "3. `forward` to define how data from a minibatch is handled to subsequently compute the loss using `self.criterion`" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "class CustomGenomeLevelPST(BasePST[CustomModelConfig]): # <- this is optional to specify the config type here, but enables IDEs to provide better autocompletion\n", + " def __init__(self, config: CustomModelConfig):\n", + " super().__init__(config)\n", + "\n", + " # define new layers for new objective\n", + " self.pred_layer = torch.nn.Linear(self.config.out_dim, 1)\n", + "\n", + " def setup_objective(self, tunable_weight: float, **kwargs) -> CustomLossFn:\n", + " # notice how the var name is the same as in the CustomLossConfig -- those fields get passed\n", + " # as keyword arguments to this method\n", + " return CustomLossFn(tunable_weight)\n", + "\n", + " def forward(self, batch: GenomeGraphBatch, stage: str, **kwargs):\n", + " # add strand/pos embeddings\n", + " x_cat, _, _ = self.internal_embeddings(batch)\n", + "\n", + " pst_output, _ = self.databatch_forward(batch=batch, x=x_cat)\n", + "\n", + " y_pred = self.pred_layer(pst_output).squeeze()\n", + " y_true = batch.y\n", + "\n", + " loss = self.criterion(y_pred, y_true)\n", + "\n", + " self.log_loss(loss, batch.num_proteins.numel(), stage)\n", + "\n", + " return loss" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now that we have a custom model defined, let's see an extremely trivial example. In the sample dataset provided, there are 8 genomes that we will randomly generate a binary label for.\n", + "\n", + "Then we will use our model's loss function (which is primarily just a binary cross entropy loss for binary classification) to train this model with help from `lightning.Trainer`.\n", + "\n", + "Let's start by loading the sample dataset:" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "ckptfile = \"pst-small_trained_model.ckpt\"\n", + "data_file = \"sample_dataset.graphfmt.h5\"\n", + "datamodule = GenomeDataModule.from_pretrained(\n", + " checkpoint_path=ckptfile, data_file=data_file, shuffle=False\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we want to add a `y` field to our dataset that contains our randomly generated labels. NOTE: We store this in a `y` field since our model's `forward` method refers to the `y` attribute of the minibatch object (`batch.y`)\n", + "\n", + "Here is how we can register new dataset attributes using the `GenomeDataModule.register_feature` method." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "dataset = datamodule.dataset\n", + "n_genomes = len(dataset)\n", + "# randomly generated genome level labels\n", + "y_true = (torch.rand(n_genomes) >= 0.5).float()\n", + "\n", + "datamodule.register_feature(\"y\", y_true, feature_level=\"genome\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Then all you need to do is train your model!" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "GPU available: False, used: False\n", + "TPU available: False, using: 0 TPU cores\n", + "IPU available: False, using: 0 IPUs\n", + "HPU available: False, using: 0 HPUs\n", + "/scratch/ccmartin6/miniconda3/envs/pst/lib/python3.10/site-packages/lightning/pytorch/trainer/configuration_validator.py:74: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.\n", + "Loading `train_dataloader` to estimate number of stepping batches.\n", + "/scratch/ccmartin6/miniconda3/envs/pst/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=255` in the `DataLoader` to improve performance.\n", + "\n", + " | Name | Type | Params\n", + "-------------------------------------------------------------\n", + "0 | positional_embedding | PositionalEmbedding | 81.9 K\n", + "1 | strand_embedding | Embedding | 80 \n", + "2 | model | SetTransformer | 5.3 M \n", + "3 | criterion | CustomLossFn | 0 \n", + "4 | pred_layer | Linear | 401 \n", + "-------------------------------------------------------------\n", + "5.4 M Trainable params\n", + "0 Non-trainable params\n", + "5.4 M Total params\n", + "21.566 Total estimated model params size (MB)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "3dd4033c723446d1a08028bf01d9ec60", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Training: | | 0/? [00:00= 0.5).float()\n", + "\n", + "# accuracy\n", + "(pred == batch.y).sum() / pred.size(0)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Finetuning with a new protein-level objective\n", + "The vPST was pretrained with genome-level objective. However, it internally computes contextualized protein embeddings using genome context.\n", + "\n", + "If you want to focus more on these protein embeddings rather than genome embeddings, such as for a protein prediction task or even pretraining a protein foundation model, then you need to create a subclass of the `BaseProteinSetTransformerEncoder` module and update the following methods:\n", + "\n", + "1. `forward` code needed to handle a minibatch and compute the loss\n", + "2. `setup_objective` code needed to create a callable that computes the loss directly. This code is called upon initialization of the model, and the `forward` method calls the `.criterion` callable that is returned by this method.\n", + "\n", + "Additionally, if the loss function maintains state (such as the margin and scaling values of a triplet loss objective), then you can create a subclass of the `BaseModelConfig` with the loss field using a custom subclass of the `BaseLossConfig` that specifies the name and default values of stateful parameters needed by the loss function. This is only necessary for tunable hyperparameters of the loss function, NOT just any arguments needed to setup the loss function callable.\n", + "\n", + "NOTE: This is pretty much identical as the genome-level objective change above. The ONLY difference is that you need to subclass a `BaseProteinSetTransformerEncoder` class instead of `BaseProteinSetTransformer`.\n", + "\n", + "-----\n", + "\n", + "Let's look at an example where we want to predict some random binary feature about the genomes in the sample dataset provided. For demonstration purposes, we will suppose that we have some tunable weight required for the loss function." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "from pst import BaseProteinSetTransformerEncoder as BasePSTEncoder" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We are just reusing the loss function and custom model config defined in the genome-level demo to compute binary cross entropy loss for a randomly generated protein-level label." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "class CustomProteinLevelPST(BasePSTEncoder[CustomModelConfig]): # <- again note the optional config type hint here\n", + " def __init__(self, config: CustomModelConfig):\n", + " super().__init__(config)\n", + "\n", + " # define new layers for new objective\n", + " self.pred_layer = torch.nn.Linear(self.config.out_dim, 1)\n", + "\n", + " def setup_objective(self, tunable_weight: float, **kwargs) -> CustomLossFn:\n", + " return CustomLossFn(tunable_weight)\n", + " \n", + " def forward(self, batch: GenomeGraphBatch, stage: str, **kwargs):\n", + " # intentionally left this nearly identical to the previous example\n", + "\n", + " # add strand/pos embeddings\n", + " x_cat, _, _ = self.internal_embeddings(batch)\n", + "\n", + " pst_encoder_output, _, _ = self.databatch_forward(batch=batch, x=x_cat)\n", + "\n", + " y_pred = self.pred_layer(pst_encoder_output).squeeze()\n", + " y_true = batch.y\n", + "\n", + " loss = self.criterion(y_pred, y_true)\n", + "\n", + " self.log_loss(loss, batch.num_proteins.numel(), stage)\n", + "\n", + " return loss" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We already loaded the datamodule previously, so we just need to register the randomly created protein labels." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "n_proteins = dataset.data.shape[0]\n", + "y_true = (torch.rand(n_proteins) >= 0.5).float()\n", + "dataset.register_feature(\"y\", y_true, feature_level=\"protein\", overwrite_previously_registered=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Then just train the model using the `lightning.Trainer`!" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "GPU available: False, used: False\n", + "TPU available: False, using: 0 TPU cores\n", + "IPU available: False, using: 0 IPUs\n", + "HPU available: False, using: 0 HPUs\n", + "Loading `train_dataloader` to estimate number of stepping batches.\n", + "\n", + " | Name | Type | Params\n", + "---------------------------------------------------------------\n", + "0 | positional_embedding | PositionalEmbedding | 81.9 K\n", + "1 | strand_embedding | Embedding | 80 \n", + "2 | model | SetTransformerEncoder | 4.0 M \n", + "3 | criterion | CustomLossFn | 0 \n", + "4 | pred_layer | Linear | 401 \n", + "---------------------------------------------------------------\n", + "4.1 M Trainable params\n", + "0 Non-trainable params\n", + "4.1 M Total params\n", + "16.422 Total estimated model params size (MB)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "62fb8417dc6744659f2acc2f1ed98569", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Training: | | 0/? [00:00= 0.5).float()\n", + "\n", + "# accuracy\n", + "(pred == batch.y).sum() / pred.size(0)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "pst", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}