feat(ui/server): model generate with sae and steering
dest1n1s committed Aug 11, 2024
1 parent abf7437 commit b24a702
109 changes: 87 additions & 22 deletions server/app.py
@@ -1,9 +1,11 @@
import os
from typing import Any, cast
from typing import Annotated, Any, Literal, cast

import numpy as np
from pydantic import BaseModel
import torch

from transformer_lens.hook_points import HookPoint
from transformers import AutoModelForCausalLM, AutoTokenizer

from transformer_lens import HookedTransformer, HookedTransformerConfig
@@ -13,14 +15,15 @@

import msgpack

from fastapi import FastAPI, Response
from fastapi import FastAPI, Query, Response
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware

import plotly.express as px
import plotly.graph_objects as go

from lm_saes.analysis.auto_interp import check_description, generate_description
from lm_saes.circuit.context import apply_sae
from lm_saes.config import AutoInterpConfig, LanguageModelConfig, SAEConfig
from lm_saes.database import MongoClient
from lm_saes.sae import SparseAutoEncoder
@@ -297,31 +300,93 @@ def dictionary_custom_input(dictionary_name: str, input_text: str):

    return Response(content=msgpack.packb(sample), media_type="application/x-msgpack")

class SteeringConfig(BaseModel):
    sae: str
    feature_index: int
    steering_type: Literal["times", "add", "set", "ablate"]
    steering_value: float | None = None


class ModelGenerateRequest(BaseModel):
    input_text: str
    max_new_tokens: int = 128
    top_k: int = 50
    top_p: float = 0.95
    return_logits_top_k: int = 5
    saes: list[str] = []
    steerings: list[SteeringConfig] = []
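
# Steering semantics for a selected feature activation a (applied in the
# steering hook inside model_generate below):
#   times  -> a * steering_value
#   add    -> a + steering_value
#   set    -> steering_value
#   ablate -> 0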

@app.post("/model/generate")
def model_generate(input_text: str, max_new_tokens: int = 128, top_k: int = 50, top_p: float = 0.95, return_logits_top_k: int = 5):
def model_generate(request: ModelGenerateRequest):
    dictionaries = client.list_dictionaries(dictionary_series=dictionary_series)
    assert len(dictionaries) > 0, "No dictionaries found. Model name cannot be inferred."
    model = get_model(dictionaries[0])
    saes = [(get_sae(name), name) for name in request.saes]
    max_feature_acts = {
        name: client.get_max_feature_acts(name, dictionary_series=dictionary_series)
        for _, name in saes
    }
    assert all(steering.sae in request.saes for steering in request.steerings), "Steering SAE not found"

    def generate_steering_hook(steering: SteeringConfig):
        # Build one (hook_name, hook_fn) pair that edits a single SAE
        # feature's activations according to the requested steering type.
        def steering_hook(tensor: torch.Tensor, hook: HookPoint):
            assert len(tensor.shape) == 3  # (batch, seq, d_sae)
            tensor = tensor.clone()
            if steering.steering_type == "times":
                assert steering.steering_value is not None
                tensor[:, :, steering.feature_index] *= steering.steering_value
            elif steering.steering_type == "ablate":
                tensor[:, :, steering.feature_index] = 0
            elif steering.steering_type == "add":
                assert steering.steering_value is not None
                tensor[:, :, steering.feature_index] += steering.steering_value
            elif steering.steering_type == "set":
                assert steering.steering_value is not None
                tensor[:, :, steering.feature_index] = steering.steering_value
            return tensor

        sae = get_sae(steering.sae)
        return f"{sae.cfg.hook_point_out}.sae.hook_feature_acts", steering_hook

    # (hook_name, hook_fn) pairs, consumed below by model.hooks.
    steerings_hooks = [generate_steering_hook(steering) for steering in request.steerings]

    with torch.no_grad():
        input = model.to_tokens(input_text, prepend_bos=False)
        output = model.generate(input, max_new_tokens=max_new_tokens, top_k=top_k, top_p=top_p)
        output = output.clone()
        logits = model.forward(output)
        logits_topk = [torch.topk(l, return_logits_top_k) for l in logits[0]]
        result = {
            "context": [
                bytearray([byte_decoder[c] for c in t])
                for t in model.tokenizer.convert_ids_to_tokens(output[0])
            ],
            "logits": [l.values.cpu().numpy().tolist() for l in logits_topk],
            "logits_tokens": [
                [
                    bytearray([byte_decoder[c] for c in t])
                    for t in model.tokenizer.convert_ids_to_tokens(l.indices)
                ]
                for l in logits_topk
            ],
            "input_mask": [1 for _ in range(len(input[0]))] + [0 for _ in range(len(output[0]) - len(input[0]))],
        }
        with apply_sae(model, [sae for sae, _ in saes]):
            with model.hooks(steerings_hooks):
                input = model.to_tokens(request.input_text, prepend_bos=False)
                output = model.generate(input, max_new_tokens=request.max_new_tokens, top_k=request.top_k, top_p=request.top_p)
                output = output.clone()
                hook_names = [f"{sae.cfg.hook_point_out}.sae.hook_feature_acts" for sae, _ in saes]
                logits, cache = model.run_with_cache(output, names_filter=hook_names)
                logits_topk = [torch.topk(l, request.return_logits_top_k) for l in logits[0]]

                def sae_info(sae: SparseAutoEncoder, name: str):
                    # Nonzero feature activations at each token position,
                    # plus the stored per-feature maxima for those features.
                    acts = cache[f"{sae.cfg.hook_point_out}.sae.hook_feature_acts"][0]
                    indices = [acts[i].nonzero(as_tuple=True)[0] for i in range(acts.shape[0])]
                    return {
                        "name": name,
                        "feature_acts_indices": [idx.cpu().numpy().tolist() for idx in indices],
                        "feature_acts": [acts[i][idx].cpu().numpy().tolist() for i, idx in enumerate(indices)],
                        "max_feature_acts": [
                            [max_feature_acts[name][j] for j in idx.cpu().numpy().tolist()]
                            for idx in indices
                        ],
                    }

                result = {
                    "context": [
                        bytearray([byte_decoder[c] for c in t])
                        for t in model.tokenizer.convert_ids_to_tokens(output[0])
                    ],
                    "logits": [l.values.cpu().numpy().tolist() for l in logits_topk],
                    "logits_tokens": [
                        [
                            bytearray([byte_decoder[c] for c in t])
                            for t in model.tokenizer.convert_ids_to_tokens(l.indices)
                        ]
                        for l in logits_topk
                    ],
                    "input_mask": [1 for _ in range(len(input[0]))] + [0 for _ in range(len(output[0]) - len(input[0]))],
                    "sae_info": [sae_info(sae, name) for sae, name in saes],
                }
    return Response(content=msgpack.packb(result), media_type="application/x-msgpack")
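
For reference, a minimal client sketch for the new endpoint. The base URL, the SAE name "pythia-160m-L3-resid", and the steering payload are all illustrative placeholders, not values shipped with this commit:

import msgpack
import requests

# Generate with one SAE attached, quadrupling feature 42 wherever it fires.
resp = requests.post(
    "http://localhost:8000/model/generate",
    json={
        "input_text": "The quick brown fox",
        "max_new_tokens": 32,
        "saes": ["pythia-160m-L3-resid"],
        "steerings": [
            {
                "sae": "pythia-160m-L3-resid",
                "feature_index": 42,
                "steering_type": "times",
                "steering_value": 4.0,
            }
        ],
    },
)
resp.raise_for_status()
result = msgpack.unpackb(resp.content)
# Tokens are returned as raw bytes; join and decode them for display.
print(b"".join(result["context"]).decode("utf-8", errors="replace"))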


