
Commit

feat(test_train_sae): format with ruff
Frankstein73 committed Oct 30, 2024
1 parent 4cc893d commit c5f9dbf
Showing 2 changed files with 5 additions and 5 deletions.
1 change: 1 addition & 0 deletions tests/intergration/test_attributor.py
@@ -9,6 +9,7 @@
 
 class TestModule(HookedRootModule):
     __test__ = False
+
     def __init__(self):
         super().__init__()
         self.W_1 = nn.Parameter(torch.tensor([[1.0, 2.0]]))
9 changes: 4 additions & 5 deletions tests/intergration/test_train_sae.py
@@ -1,9 +1,10 @@
-from transformer_lens import HookedTransformer, HookedTransformerConfig
 import torch
+from einops import rearrange
 from torch.optim import Adam
+from transformer_lens import HookedTransformer, HookedTransformerConfig
+
 from lm_saes.config import SAEConfig
 from lm_saes.sae import SparseAutoEncoder
-from einops import rearrange
 
 
 def test_train_sae():
@@ -41,9 +42,7 @@ def test_train_sae():
     ### Get activations ###
     tokens = torch.randint(0, 50, (batch_size, 10))
     with torch.no_grad():
-        _, cache = model.run_with_cache_until(
-            tokens, names_filter=hook_point, until=hook_point
-        )
+        _, cache = model.run_with_cache_until(tokens, names_filter=hook_point, until=hook_point)
     batch = {
         hook_point: rearrange(
             cache[hook_point].to(dtype=dtype, device=device),
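For context, the reformatted call collects activations at one hook point so the test can build a training batch for the sparse autoencoder. Below is a minimal sketch of that step using transformer_lens's standard run_with_cache API (run_with_cache_until is this repository's own helper; based on its name, it presumably also stops the forward pass at the hook point). The model name, hook point, and shapes are illustrative assumptions, not values taken from the test.

import torch
from einops import rearrange
from transformer_lens import HookedTransformer

# Illustrative choices; the test builds its own small model from a
# HookedTransformerConfig rather than loading a pretrained one.
model = HookedTransformer.from_pretrained("gpt2")
hook_point = "blocks.0.hook_mlp_out"

tokens = torch.randint(0, model.cfg.d_vocab, (4, 10))  # (batch, seq)
with torch.no_grad():
    # names_filter limits the cache to the one activation the SAE trains on.
    _, cache = model.run_with_cache(tokens, names_filter=hook_point)

# Flatten (batch, seq, d_model) -> (batch * seq, d_model), as the test does
# with einops.rearrange before handing the batch to the SAE.
activations = rearrange(cache[hook_point], "b l d -> (b l) d")
print(activations.shape)  # (40, 768) for gpt2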
