From c5f9dbfd126d457086c386afc6cf3ff05be3c8b8 Mon Sep 17 00:00:00 2001 From: Frankstein73 <1053905229@qq.com> Date: Wed, 30 Oct 2024 22:17:32 +0800 Subject: [PATCH] feat(test_train_sae): format with ruff --- tests/intergration/test_attributor.py | 1 + tests/intergration/test_train_sae.py | 9 ++++----- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/intergration/test_attributor.py b/tests/intergration/test_attributor.py index 7856922..25790a1 100644 --- a/tests/intergration/test_attributor.py +++ b/tests/intergration/test_attributor.py @@ -9,6 +9,7 @@ class TestModule(HookedRootModule): __test__ = False + def __init__(self): super().__init__() self.W_1 = nn.Parameter(torch.tensor([[1.0, 2.0]])) diff --git a/tests/intergration/test_train_sae.py b/tests/intergration/test_train_sae.py index 07df386..8ab3168 100644 --- a/tests/intergration/test_train_sae.py +++ b/tests/intergration/test_train_sae.py @@ -1,9 +1,10 @@ -from transformer_lens import HookedTransformer, HookedTransformerConfig import torch +from einops import rearrange from torch.optim import Adam +from transformer_lens import HookedTransformer, HookedTransformerConfig + from lm_saes.config import SAEConfig from lm_saes.sae import SparseAutoEncoder -from einops import rearrange def test_train_sae(): @@ -41,9 +42,7 @@ def test_train_sae(): ### Get activations ### tokens = torch.randint(0, 50, (batch_size, 10)) with torch.no_grad(): - _, cache = model.run_with_cache_until( - tokens, names_filter=hook_point, until=hook_point - ) + _, cache = model.run_with_cache_until(tokens, names_filter=hook_point, until=hook_point) batch = { hook_point: rearrange( cache[hook_point].to(dtype=dtype, device=device),