
Commit ec8b911

address #15
lucidrains committed Feb 3, 2024
1 parent 34c99a6 · commit ec8b911
Showing 3 changed files with 19 additions and 5 deletions.
self_rewarding_lm_pytorch/dpo.py (2 additions, 2 deletions)
@@ -1,7 +1,7 @@
 import os
 from pathlib import Path
 from copy import deepcopy
-from functools import cache
+from functools import lru_cache
 from collections import namedtuple
 from dataclasses import dataclass

@@ -55,7 +55,7 @@ def cycle(dl):
         for batch in dl:
             yield batch
 
-@cache
+@lru_cache(maxsize = None)
 def is_distributed():
     return dist.is_initialized() and dist.get_world_size() > 1

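Side note on the change above: `functools.cache` only exists on Python 3.9+, while `lru_cache(maxsize = None)` is the equivalent unbounded cache and also runs on older interpreters, which is presumably what #15 asked for. A minimal, self-contained illustration of the equivalence (the stub function below is hypothetical, not from this repo):

```python
from functools import lru_cache

# functools.cache (3.9+) is just lru_cache(maxsize=None) under the hood,
# so this decorator gives the same memoization on Python 3.8.
@lru_cache(maxsize = None)
def expensive_check() -> bool:
    print('computed once')   # runs only on the first call
    return True

expensive_check()  # prints 'computed once'
expensive_check()  # served from the cache, prints nothing
```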
self_rewarding_lm_pytorch/self_rewarding_lm_pytorch.py (16 additions, 2 deletions)
@@ -382,6 +382,7 @@ def __init__(
         model: Module,
         prompt_dataset: Dataset,
         num_preference_pairs: int,
+        accelerator: Accelerator,
         tokenizer_encode: Callable[[str], TensorType['seq', int]],
         tokenizer_decode: Callable[[TensorType['seq', int]], str],
         batch_size: int = 16,
@@ -475,6 +476,12 @@ def __init__(
         self.prompt_len_memmap = open_memmap(str(self.prompt_len_memmap_path), dtype = 'int', mode = 'w+', shape = (num_preference_pairs,))
         self.self_reward_memmap_file = open_memmap(str(self.self_reward_mmemap_path), dtype = 'float32', mode = 'w+', shape = (num_preference_pairs, 2))
 
+        self.accelerator = accelerator
+
+    @property
+    def device(self):
+        return self.accelerator.device
+
     def generate_reward(
         self,
         prompt: str,
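The new `accelerator` argument and `device` property above let the dataset generator ask Hugging Face Accelerate where tensors should live instead of assuming CPU. A rough sketch of that pattern with a hypothetical class name (assumes `torch` and `accelerate` are installed):

```python
import torch
from accelerate import Accelerator

class RewardGeneratorSketch:
    # hypothetical stand-in for the generator class patched above
    def __init__(self, model: torch.nn.Module, accelerator: Accelerator):
        self.model = model
        self.accelerator = accelerator   # stashed so .device stays in sync with Accelerate

    @property
    def device(self) -> torch.device:
        return self.accelerator.device   # cuda / mps / cpu, whichever Accelerate picked

gen = RewardGeneratorSketch(torch.nn.Linear(4, 4), Accelerator())
prompt = torch.randint(0, 256, (1, 8)).to(gen.device)   # prompt now lives on the right device
```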
@@ -496,8 +503,11 @@ def generate_reward(
 
         reward_prompt = repeat(reward_prompt, 'n -> b n', b = self.num_evals_to_average)
 
+        reward_prompt = reward_prompt.to(self.device)
+        model = self.model.to(self.device)
+
         reward_responses = sample(
-            self.model,
+            model,
             prompts = reward_prompt,
             seq_len = self.generate_reward_max_seq_len,
             temperature = self.eval_temperature,
@@ -536,13 +546,16 @@ def forward(self) -> DPODataset:
 
         responses = []
 
+        model = self.model.to(self.device)
+
         for prompt, prompt_tensor in zip(prompts, prompt_tensors):
 
             prompt_len = prompt_tensor.shape[-1]
             repeated_prompt_tensor = repeat(prompt_tensor, 'n -> r n', r = self.num_candidate_responses)
+            repeated_prompt_tensor = repeated_prompt_tensor.to(self.device)
 
             candidate_tensor_responses = sample(
-                self.model,
+                model,
                 prompts = repeated_prompt_tensor,
                 seq_len = self.preference_max_seq_len,
                 temperature = self.gen_temperature,
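The two hunks above are the core of the fix: both `generate_reward` and `forward` now move the model and the prompt tensors onto `self.device` before calling `sample`, so generation no longer mixes CPU-resident prompts with a GPU-resident model (presumably the failure reported in #15). A self-contained sketch of that placement pattern, with stand-in names in place of the repo's `sample`:

```python
import torch
from torch import nn

device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = nn.Linear(16, 16)     # stand-in for self.model
prompts = torch.randn(2, 16)  # stand-in for the repeated prompt tensor (starts on CPU)

model = model.to(device)      # mirrors: model = self.model.to(self.device)
prompts = prompts.to(device)  # mirrors: repeated_prompt_tensor.to(self.device)

with torch.no_grad():
    out = model(prompts)      # stand-in for sample(model, prompts = prompts, ...)

# everything now lives on the same device, so no cpu/cuda mismatch
assert out.device == prompts.device == next(model.parameters()).device
```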
@@ -802,6 +815,7 @@ def __init__(
             eval_temperature = config.eval_temperature,
             eval_filter_fn = config.eval_filter_fn,
             eval_filter_kwargs = config.eval_filter_kwargs,
+            accelerator = self.accelerator,
             **config.reward_generator_kwargs
         )

setup.py (1 addition, 1 deletion)
@@ -3,7 +3,7 @@
 setup(
   name = 'self-rewarding-lm-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.2.4',
+  version = '0.2.5',
   license='MIT',
   description = 'Self Rewarding LM - Pytorch',
   author = 'Phil Wang',
