diff --git a/src/lm_saes/analysis/sample_feature_activations.py b/src/lm_saes/analysis/sample_feature_activations.py index 7308c00..3f95ddc 100644 --- a/src/lm_saes/analysis/sample_feature_activations.py +++ b/src/lm_saes/analysis/sample_feature_activations.py @@ -65,8 +65,19 @@ def sample_feature_activations( _, cache = model.run_with_cache_until(batch, names_filter=[cfg.sae.hook_point_in, cfg.sae.hook_point_out], until=cfg.sae.hook_point_out) activation_in, activation_out = cache[cfg.sae.hook_point_in], cache[cfg.sae.hook_point_out] + filter_mask = torch.logical_or( + batch.eq(model.tokenizer.eos_token_id), + batch.eq(model.tokenizer.pad_token_id) + ) + filter_mask = torch.logical_or( + filter_mask, + batch.eq(model.tokenizer.bos_token_id) + ) + feature_acts = sae.encode(activation_in, label=activation_out)[..., start_index: end_index] + feature_acts[filter_mask] = 0 + act_times += feature_acts.gt(0.0).sum(dim=[0, 1]) for name in cfg.subsample.keys(): diff --git a/src/lm_saes/evals.py b/src/lm_saes/evals.py index 169a22c..75a2e42 100644 --- a/src/lm_saes/evals.py +++ b/src/lm_saes/evals.py @@ -23,6 +23,8 @@ def run_evals( ): ### Evals eval_tokens = activation_store.next_tokens(cfg.act_store.dataset.store_batch_size) + + assert eval_tokens is not None, "Activation store is empty" # Get Reconstruction Score losses_df = recons_loss_batched( @@ -41,13 +43,15 @@ def run_evals( # get cache _, cache = model.run_with_cache_until( eval_tokens, - prepend_bos=False, names_filter=[cfg.sae.hook_point_in, cfg.sae.hook_point_out], until=cfg.sae.hook_point_out, ) + filter_mask = torch.logical_and(eval_tokens.ne(model.tokenizer.eos_token_id), eval_tokens.ne(model.tokenizer.pad_token_id)) + filter_mask = torch.logical_and(filter_mask, eval_tokens.ne(model.tokenizer.bos_token_id)) + # get act - original_act_in, original_act_out = cache[cfg.sae.hook_point_in], cache[cfg.sae.hook_point_out] + original_act_in, original_act_out = cache[cfg.sae.hook_point_in][filter_mask], cache[cfg.sae.hook_point_out][filter_mask] feature_acts = sae.encode(original_act_in, label=original_act_out) reconstructed = sae.decode(feature_acts) @@ -144,7 +148,6 @@ def get_recons_loss( _, cache = model.run_with_cache_until( batch_tokens, - prepend_bos=False, names_filter=[cfg.sae.hook_point_in, cfg.sae.hook_point_out], until=cfg.sae.hook_point_out, )