diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index 0d65f9c9..5b0ad611 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -50,3 +50,5 @@ jobs: run: pdm install - name: Type check run: pdm run mypy . + - name: Unit tests + run: pdm run pytest ./tests diff --git a/tests/test_activation_source.py b/tests/test_activation_source.py deleted file mode 100644 index 78716093..00000000 --- a/tests/test_activation_source.py +++ /dev/null @@ -1,52 +0,0 @@ -from lm_saes.activation.activation_source import TokenActivationSource -from lm_saes.activation.token_source import TokenSource - -from datasets import load_dataset -from transformer_lens import HookedTransformer - -import torch - -import pytest - -@pytest.fixture -def dataset(): - return load_dataset("Skylion007/openwebtext", split="train") - -@pytest.fixture -def dataloader(dataset): - return torch.utils.data.DataLoader(dataset, batch_size=32) - -@pytest.fixture -def model(): - return HookedTransformer.from_pretrained('gpt2') - -def test_token_source(dataloader, model): - token_source = TokenSource( - dataloader=dataloader, - model=model, - is_dataset_tokenized=False, - seq_len=128, - device="cuda", - ) - tokens = token_source.next(4) - print(tokens.detach().cpu().float().numpy()) - -def test_token_activation_source(dataloader, model): - token_source = TokenSource( - dataloader=dataloader, - model=model, - is_dataset_tokenized=False, - seq_len=128, - device="cuda", - ) - act_source = TokenActivationSource( - token_source=token_source, - model=model, - token_batch_size=32, - hook_point="activation", - device="cuda", - dtype=torch.float32, - ) - act = act_source.next(4) - print(act["activation"].detach().cpu().float().numpy()) - diff --git a/tests/unit/test_example.py b/tests/unit/test_example.py new file mode 100644 index 00000000..bef80200 --- /dev/null +++ b/tests/unit/test_example.py @@ -0,0 +1,9 @@ +import pytest + + +def func(x): + return x + 1 + + +def test_answer(): + assert func(4) == 5