Skip to content

Commit

Permalink
Fix multiple-frame injection synchronization
Browse files Browse the repository at this point in the history
  • Loading branch information
AdamBelfki3 committed Dec 5, 2024
1 parent 4ca6718 commit bd80a5c
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 5 deletions.
3 changes: 1 addition & 2 deletions src/nnsight/tracing/backends/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,7 @@ def __call__(self, graph: Graph) -> None:
for key, value in frame.f_locals.items():
if isinstance(value, Proxy) and value.node.done:
frame.f_locals[key] = value.value

ctypes.pythonapi.PyFrame_LocalsToFast(ctypes.py_object(frame), 0)
ctypes.pythonapi.PyFrame_LocalsToFast(ctypes.py_object(frame), 0)

except StopProtocol.StopException:

Expand Down
6 changes: 3 additions & 3 deletions tests/test_vllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def test_multi_token_generation(vllm_gpt2, MSG_prompt: str):
logits.append(vllm_gpt2.logits.output)
vllm_gpt2.logits.next()

assert vllm_gpt2.tokenizer.batch_decode([logit.argmax(dim=-1) for logit in logits.value]) == [" New", " York", " City"]
assert vllm_gpt2.tokenizer.batch_decode([logit.argmax(dim=-1) for logit in logits]) == [" New", " York", " City"]


""" def test_max_token_generation(vllm_gpt2, ET_prompt: str):
Expand Down Expand Up @@ -182,10 +182,10 @@ def test_mutli_token_generation_with_intervention(tp, vllm_gpt2, MSG_prompt: str
logits.append(vllm_gpt2.logits.output)
vllm_gpt2.logits.next()

assert [torch.all(hs == 0) for hs in hs_list.value] == [False, False, True, False, False]
assert [torch.all(hs == 0) for hs in hs_list] == [False, False, True, False, False]

if tp == 1:
assert vllm_gpt2.tokenizer.batch_decode([logit.argmax(dim=-1) for logit in logits.value]) == [' New', ' York', '\n', '\n', 'The']
assert vllm_gpt2.tokenizer.batch_decode([logit.argmax(dim=-1) for logit in logits]) == [' New', ' York', '\n', '\n', 'The']


""" def test_multi_referenced_module(vllm_gpt2, ET_prompt: str):
Expand Down

0 comments on commit bd80a5c

Please sign in to comment.