Skip to content

Commit

Permalink
add requirements and fix minor bug in gpt2 example
Browse files Browse the repository at this point in the history
  • Loading branch information
leruis committed Feb 2, 2024
1 parent c01fa48 commit fce30df
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 2 deletions.
4 changes: 2 additions & 2 deletions examples/gpt2_influence/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train_loader, _, valid_loader = get_loaders()
model = construct_model().to(device)
model = construct_model()[0].to(device)


def train(
Expand Down Expand Up @@ -100,7 +100,7 @@ def model_evaluate(model: nn.Module, loader: torch.utils.data.DataLoader) -> flo
start_time = time.time()

set_seed(i)
model = construct_model()
model, _ = construct_model()

train(
model=model,
Expand Down
3 changes: 3 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,7 @@ torch
einops

pyyaml
transformers
datasets
accelerate

0 comments on commit fce30df

Please sign in to comment.