Commit: update docs

kristian-georgiev committed Nov 2, 2023
1 parent 1f415aa commit 102edbf

Showing 3 changed files with 69 additions and 40 deletions.

docs/source/bert.rst (87 changes: 57 additions & 30 deletions)

We slightly redefine the :code:`forward` function so that we can pass in the inputs (:code:`input_ids`, etc.) as positional arguments instead of as keyword arguments.
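
A minimal sketch of such a wrapper (the class name is an assumption; the :code:`.logits`
attribute follows the standard HuggingFace sequence-classification output):

.. code-block:: python

    class SequenceClassificationModel(ch.nn.Module):
        """Wraps a HuggingFace model so that tokenizer outputs can be
        passed as positional arguments."""
        def __init__(self, hf_model):
            super().__init__()
            self.model = hf_model

        def forward(self, input_ids, token_type_ids, attention_mask):
            return self.model(input_ids=input_ids,
                              token_type_ids=token_type_ids,
                              attention_mask=attention_mask).logits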

For data loading, we adapt the code from the HuggingFace example:

.. code-block:: python

    # NOTE: CHANGE THIS IF YOU WANT TO RUN ON FULL DATASET
    TRAIN_SET_SIZE = 5_000
    VAL_SET_SIZE = 10

    def init_loaders(batch_size=16):
        ds_train = get_dataset('train')
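
The snippet above is truncated. A minimal sketch of how the full loader setup might
look, assuming :code:`get_dataset` returns a tokenized, tensor-formatted dataset split
(the split name and the :code:`.select()` subsampling are assumptions):

.. code-block:: python

    from torch.utils.data import DataLoader

    def init_loaders(batch_size=16):
        # subsample each split to the sizes configured above
        ds_train = get_dataset('train').select(range(TRAIN_SET_SIZE))
        ds_val = get_dataset('validation').select(range(VAL_SET_SIZE))
        return (DataLoader(ds_train, batch_size=batch_size, shuffle=False),
                DataLoader(ds_val, batch_size=batch_size, shuffle=False))

    loader_train, loader_val = init_loaders()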
The model output function is implemented as follows:

.. code-block:: python

    def get_output(
        model,
        weights: Iterable[Tensor],
        buffers: Iterable[Tensor],
        input_id: Tensor,
        token_type_id: Tensor,
        attention_mask: Tensor,
        label: Tensor,
    ) -> Tensor:
        # run the model functionally on a single example (batch dimension of 1)
        kw_inputs = {
            "input_ids": input_id.unsqueeze(0),
            "token_type_ids": token_type_id.unsqueeze(0),
            "attention_mask": attention_mask.unsqueeze(0),
        }

        logits = ch.func.functional_call(
            model, (weights, buffers), args=(), kwargs=kw_inputs
        )

        # correct-class margin: logit of the label minus logsumexp of the rest
        bindex = ch.arange(logits.shape[0]).to(logits.device, non_blocking=False)
        logits_correct = logits[bindex, label.unsqueeze(0)]

        cloned_logits = logits.clone()
        cloned_logits[bindex, label.unsqueeze(0)] = ch.tensor(
            -ch.inf, device=logits.device, dtype=logits.dtype
        )

        margins = logits_correct - cloned_logits.logsumexp(dim=-1)
        return margins.sum()

The implementation is identical to the standard classification example in the
:ref:`MODELOUTPUT tutorial`, except that the signature of the method and of the
model call are slightly different, as the language model takes in three inputs
instead of just one.
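
Because :code:`get_output` returns a scalar for a single example, per-example
gradients can be taken by composing :code:`torch.func.grad` with :code:`vmap`.
A rough sketch of this composition (not necessarily TRAK's exact internals; the
batch tensors here are assumed to be stacked along dimension 0):

.. code-block:: python

    # differentiate the scalar output w.r.t. the weights (positional argument 1)
    grad_fn = ch.func.grad(get_output, argnums=1)

    # map over the per-example inputs only; model, weights, and buffers are shared
    per_example_grads = ch.func.vmap(
        grad_fn,
        in_dims=(None, None, None, 0, 0, 0, 0),
    )(model, weights, buffers, input_ids, token_type_ids, attention_masks, labels)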

Similarly, the gradient function is implemented as follows:

.. code-block:: python

    def get_out_to_loss_grad(
        self, model, weights, buffers, batch: Iterable[Tensor]
    ) -> Tensor:
        input_ids, token_type_ids, attention_mask, labels = batch
        kw_inputs = {
            "input_ids": input_ids,
            "token_type_ids": token_type_ids,
            "attention_mask": attention_mask,
        }

        logits = ch.func.functional_call(
            model, (weights, buffers), args=(), kwargs=kw_inputs
        )
        # probability assigned to the correct label for each example in the batch
        ps = self.softmax(logits / self.loss_temperature)[
            ch.arange(logits.size(0)), labels
        ]
        return (1 - ps).clone().detach().unsqueeze(-1)
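
For intuition: the margin :math:`f` returned by :code:`get_output` satisfies
:math:`p = \sigma(f)`, where :math:`p` is the softmax probability of the correct
class, so (modulo the temperature scaling) the cross-entropy loss can be written
as :math:`\ell = \log(1 + e^{-f})`, and

.. math::

   \frac{\partial \ell}{\partial f} = -(1 - p).

Up to sign, this is exactly the per-example factor returned above, so no explicit
backward pass through the loss is needed.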

Putting it together
-------------------

Using the above :code:`TextClassificationModelOutput` implementation, we can compute
TRAK scores as follows:

.. code-block:: python

    traker = TRAKer(
        model=model,
        task=TextClassificationModelOutput,  # you can also just pass in "text_classification"
        train_set_size=TRAIN_SET_SIZE,
        save_dir=SAVE_DIR,
        device=DEVICE,
        proj_dim=1024,
    )

    def process_batch(batch):
        return batch['input_ids'], batch['token_type_ids'], batch['attention_mask'], batch['labels']

    traker.load_checkpoint(model.state_dict(), model_id=0)
    for batch in tqdm(loader_train, desc='Featurizing..'):
        # process batch into compatible form for TRAKer TextClassificationModelOutput
        batch = process_batch(batch)
        batch = [x.to(DEVICE) for x in batch]
        traker.featurize(batch=batch, num_samples=batch[0].shape[0])

    traker.finalize_features()

    traker.start_scoring_checkpoint(exp_name='qnli',
                                    checkpoint=model.state_dict(),
                                    model_id=0,
                                    num_targets=VAL_SET_SIZE)
    for batch in tqdm(loader_val, desc='Scoring..'):
        batch = process_batch(batch)
        batch = [x.to(DEVICE) for x in batch]
        traker.score(batch=batch, num_samples=batch[0].shape[0])

    scores = traker.finalize_scores(exp_name='qnli')

We use :code:`process_batch` to transform each batch from a dictionary (the form
produced by HuggingFace dataloaders) into a tuple.
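
Equivalently, :code:`process_batch` could be written with :code:`operator.itemgetter`:

.. code-block:: python

    from operator import itemgetter

    # same dictionary-to-tuple transformation as process_batch above
    process_batch = itemgetter('input_ids', 'token_type_ids', 'attention_mask', 'labels')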

That's all! You can find this tutorial as a complete script `here <https://gi

Extending to other tasks
----------------------------------

For a more involved example that is *not* classification, see :ref:`CLIP tutorial`.

docs/source/clip.rst (19 changes: 12 additions & 7 deletions)

Now we are ready to implement :meth:`.CLIPModelOutput.get_output`:

.. code-block:: python

    def get_output(model,
                   weights: Iterable[Tensor],
                   buffers: Iterable[Tensor],
                   image: Tensor,
                   label: Tensor):
        # tailored for open_clip
        # https://github.com/mlfoundations/open_clip/blob/fb72f4db1b17133befd6c67c9cf32a533b85a321/src/open_clip/model.py#L242-L245
        clip_inputs = {"image": image.unsqueeze(0), "text": label.unsqueeze(0)}
        image_embeddings, text_embeddings, _ = ch.func.functional_call(
            model, (weights, buffers), args=(), kwargs=clip_inputs
        )
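
The remainder of :code:`get_output` is elided above. One plausible continuation,
sketched under the assumption that the output is a contrastive margin of the
matching image-text pair against the reference embeddings precomputed by
:code:`get_embeddings` (the :code:`sim_*` attribute names are assumptions, not
TRAK's exact API):

.. code-block:: python

        # hypothetical continuation of get_output; attribute names are assumptions
        all_img = CLIPModelOutput.sim_image_embeddings  # reference image embeddings, [N, d]
        all_txt = CLIPModelOutput.sim_text_embeddings   # reference text embeddings, [N, d]

        # similarity of the matching pair vs. all reference pairs, in both directions
        positive = (image_embeddings * text_embeddings).sum(-1)
        margin_img = positive - (image_embeddings @ all_txt.T).logsumexp(dim=-1)
        margin_txt = positive - (text_embeddings @ all_img.T).logsumexp(dim=-1)
        return (margin_img + margin_txt).sum()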

Using the above :code:`CLIPModelOutput` implementation, we can compute TRAK scores
as follows:

.. code-block:: python

    traker = TRAKer(model=model,
                    task=CLIPModelOutput,
                    train_set_size=TRAIN_SET_SIZE,  # assumed, mirroring the BERT example
                    save_dir=SAVE_DIR,              # assumed, mirroring the BERT example
                    device=device,
                    proj_dim=1024)

    traker.task.get_embeddings(model, ds_train, batch_size=1, size=600, embedding_dim=1024,
                               preprocess_fn_img=lambda x: preprocess(x).to(device).unsqueeze(0),
                               preprocess_fn_txt=lambda x: tokenizer(x[0]).to(device))

    traker.load_checkpoint(model.state_dict(), model_id=0)
    for (img, captions) in tqdm(loader_train, desc='Featurizing...'):
        x = preprocess(img).to('cuda').unsqueeze(0)
        y = tokenizer(captions).to('cuda')
        traker.featurize(batch=(x, y), num_samples=x.shape[0])

    traker.finalize_features()

    traker.start_scoring_checkpoint(exp_name='clip_example',
                                    checkpoint=model.state_dict(),
                                    model_id=0,
                                    num_targets=VAL_SET_SIZE)
    for (img, captions) in tqdm(loader_val, desc='Scoring...'):
        x = preprocess(img).to('cuda').unsqueeze(0)
        y = tokenizer(captions).to('cuda')
        traker.score(batch=(x, y), num_samples=x.shape[0])

    scores = traker.finalize_scores(exp_name='clip_example')

docs/source/modeloutput.rst (3 changes: 0 additions & 3 deletions)

You can then pass your custom model output class as the :code:`task` when
instantiating :class:`.TRAKer`:

.. code-block:: python

    def get_output(...):
        # Implement

    def get_out_to_loss_grad(...):
        # Implement
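
For concreteness, a fuller skeleton might look like the following. The
:code:`AbstractModelOutput` base class is TRAK's; the import path, class name,
and method bodies below are illustrative assumptions:

.. code-block:: python

    from typing import Iterable

    from torch import Tensor
    from trak.modelout_functions import AbstractModelOutput  # path assumed

    class CustomModelOutput(AbstractModelOutput):
        @staticmethod
        def get_output(model, weights, buffers, *example) -> Tensor:
            # scalar model output for a single example, e.g. a correct-class
            # margin as in the classification tutorials
            ...

        def get_out_to_loss_grad(self, model, weights, buffers,
                                 batch: Iterable[Tensor]) -> Tensor:
            # per-example derivative of the loss w.r.t. the output above
            ...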
