diff --git a/docs/source/bert.rst b/docs/source/bert.rst
index 59ed77f..b37d8fc 100644
--- a/docs/source/bert.rst
+++ b/docs/source/bert.rst
@@ -63,7 +63,7 @@ to fit our API signatures.
 We slightly redefine the :code:`forward` function so that we can pass in the inputs
 (:code:`input_ids`, etc.) as positional arguments instead of as keyword arguments.
 
-For data loading, we adapt the code from Hugging Face example:
+For data loading, we adapt the code from the HuggingFace example:
 
 .. raw:: html
 
@@ -132,7 +132,7 @@ For data loading, we adapt the code from Hugging Face example:
 
     # NOTE: CHANGE THIS IF YOU WANT TO RUN ON FULL DATASET
    TRAIN_SET_SIZE = 5_000
-    VAL_SET_SIZE = 1_00
+    VAL_SET_SIZE = 10
 
    def init_loaders(batch_size=16):
        ds_train = get_dataset('train')
@@ -180,38 +180,59 @@ The model output function is implemented as follows:
 
 .. code-block:: python
 
-    def get_output(func_model,
-                   weights: Iterable[Tensor],
-                   buffers: Iterable[Tensor],
-                   input_id: Tensor,
-                   token_type_id: Tensor,
-                   attention_mask: Tensor,
-                   label: Tensor,
-                   ) -> Tensor:
-        logits = func_model(weights, buffers, input_id.unsqueeze(0),
-                            token_type_id.unsqueeze(0),
-                            attention_mask.unsqueeze(0))
+    def get_output(
+        model,
+        weights: Iterable[Tensor],
+        buffers: Iterable[Tensor],
+        input_id: Tensor,
+        token_type_id: Tensor,
+        attention_mask: Tensor,
+        label: Tensor,
+    ) -> Tensor:
+        kw_inputs = {
+            "input_ids": input_id.unsqueeze(0),
+            "token_type_ids": token_type_id.unsqueeze(0),
+            "attention_mask": attention_mask.unsqueeze(0),
+        }
+
+        logits = ch.func.functional_call(
+            model, (weights, buffers), args=(), kwargs=kw_inputs
+        )
         bindex = ch.arange(logits.shape[0]).to(logits.device, non_blocking=False)
         logits_correct = logits[bindex, label.unsqueeze(0)]
 
         cloned_logits = logits.clone()
-        cloned_logits[bindex, label.unsqueeze(0)] = ch.tensor(-ch.inf).to(logits.device)
+        cloned_logits[bindex, label.unsqueeze(0)] = ch.tensor(
+            -ch.inf, device=logits.device, dtype=logits.dtype
+        )
 
         margins = logits_correct - cloned_logits.logsumexp(dim=-1)
         return margins.sum()
 
-The implementation is identical to the standard classification example in :ref:`MODELOUTPUT tutorial`,
-except here the signature of the method and the :code:`func_model` is slightly different
-as the language model takes in three inputs instead of just one.
+The implementation is identical to the standard classification example in the
+:ref:`MODELOUTPUT tutorial`, except that the method's signature is slightly
+different, as the language model takes in three inputs instead of just
+one.
 
 Similarly, the gradient function is implemented as follows:
 
 .. code-block:: python
 
-    def get_out_to_loss_grad(self, func_model, weights, buffers, batch: Iterable[Tensor]) -> Tensor:
+    def get_out_to_loss_grad(
+        self, model, weights, buffers, batch: Iterable[Tensor]
+    ) -> Tensor:
         input_ids, token_type_ids, attention_mask, labels = batch
-        logits = func_model(weights, buffers, input_ids, token_type_ids, attention_mask)
-        ps = self.softmax(logits / self.loss_temperature)[ch.arange(logits.size(0)), labels]
+        kw_inputs = {
+            "input_ids": input_ids,
+            "token_type_ids": token_type_ids,
+            "attention_mask": attention_mask,
+        }
+        logits = ch.func.functional_call(
+            model, (weights, buffers), args=(), kwargs=kw_inputs
+        )
+        ps = self.softmax(logits / self.loss_temperature)[
+            ch.arange(logits.size(0)), labels
+        ]
         return (1 - ps).clone().detach().unsqueeze(-1)
 
 Putting it together
 -------------------
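For readers unfamiliar with the calling convention this patch migrates to, below is a minimal,
self-contained sketch of how :code:`ch.func.functional_call` runs a module with explicit
parameter and buffer dictionaries, the pattern the updated :code:`get_output` and
:code:`get_out_to_loss_grad` rely on. It assumes PyTorch >= 2.0 and the tutorial's
:code:`import torch as ch` convention; the toy :code:`nn.Linear` model is illustrative only,
standing in for the BERT classifier.

.. code-block:: python

    import torch as ch
    from torch import nn

    # Toy stand-in for the BERT classifier (illustrative only)
    model = nn.Linear(4, 2)

    # Detached copies of the model's parameters and buffers, keyed by name
    weights = {k: v.detach() for k, v in model.named_parameters()}
    buffers = {k: v.detach() for k, v in model.named_buffers()}

    x = ch.randn(3, 4)
    # Run the module with the explicit dicts instead of its registered
    # parameters; a tuple of dicts is merged into one mapping.
    out = ch.func.functional_call(model, (weights, buffers), args=(x,))
    assert ch.allclose(out, model(x))

The point of this pattern is that the weights become explicit inputs to the computation,
which is what lets :code:`torch.func` transforms such as :code:`grad` and :code:`vmap`
differentiate with respect to them.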
@@ -221,12 +242,14 @@ Using the above :code:`TextClassificationModelOutput` implementation, we can com
 
 .. code-block:: python
 
-    traker = TRAKer(model=model,
-                    task=TextClassificationModelOutput, # you can also just pass in "text_classification"
-                    train_set_size=TRAIN_SET_SIZE,
-                    save_dir=args.out,
-                    device=device,
-                    proj_dim=1024)
+    traker = TRAKer(
+        model=model,
+        task=TextClassificationModelOutput, # you can also just pass in "text_classification"
+        train_set_size=TRAIN_SET_SIZE,
+        save_dir=SAVE_DIR,
+        device=DEVICE,
+        proj_dim=1024,
+    )
 
     def process_batch(batch):
         return batch['input_ids'], batch['token_type_ids'], batch['attention_mask'], batch['labels']
@@ -235,18 +258,21 @@ Using the above :code:`TextClassificationModelOutput` implementation, we can com
 
 .. code-block:: python
 
     for batch in tqdm(loader_train, desc='Featurizing..'):
         # process batch into compatible form for TRAKer TextClassificationModelOutput
         batch = process_batch(batch)
-        batch = [x.cuda() for x in batch]
+        batch = [x.to(DEVICE) for x in batch]
         traker.featurize(batch=batch, num_samples=batch[0].shape[0])
 
     traker.finalize_features()
 
-    traker.start_scoring_checkpoint(model.state_dict(), model_id=0, num_targets=VAL_SET_SIZE)
+    traker.start_scoring_checkpoint(exp_name='qnli',
+                                    checkpoint=model.state_dict(),
+                                    model_id=0,
+                                    num_targets=VAL_SET_SIZE)
     for batch in tqdm(loader_val, desc='Scoring..'):
         batch = process_batch(batch)
-        batch = [x.cuda() for x in batch]
+        batch = [x.to(DEVICE) for x in batch]
         traker.score(batch=batch, num_samples=batch[0].shape[0])
 
-    scores = traker.finalize_scores()
+    scores = traker.finalize_scores(exp_name='qnli')
 
-We use :code:`process_batch` to transform the batch from dictionary (which is the form used by Hugging Face dataloaders) to a tuple.
+We use :code:`process_batch` to transform the batch from a dictionary (the form
+used by Hugging Face dataloaders) into a tuple.
@@ -256,4 +282,5 @@ That's all! You can find this tutorial as a complete script in `here
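As a closing illustration (not from the tutorial itself), here is a tiny numeric sketch of
the margin that :code:`get_output` computes: the correct-class logit minus the logsumexp of
all remaining logits. The toy logit values and label are made up for illustration.

.. code-block:: python

    import torch as ch

    logits = ch.tensor([[2.0, 0.5, -1.0]])  # one example, three classes
    label = ch.tensor([0])                  # correct class is index 0

    bindex = ch.arange(logits.shape[0])
    logits_correct = logits[bindex, label]  # tensor([2.0])

    cloned_logits = logits.clone()
    cloned_logits[bindex, label] = -ch.inf  # mask out the correct class
    margin = logits_correct - cloned_logits.logsumexp(dim=-1)
    # margin > 0 exactly when the correct class outscores the
    # (soft) maximum over the remaining classes

The :code:`.sum()` at the end of :code:`get_output` then reduces the per-example margins to
a scalar, so that autodiff can take gradients of the model output with respect to the
weights passed to :code:`ch.func.functional_call`.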