diff --git a/server/app.py b/server/app.py
index 049edf32..345a3e4d 100644
--- a/server/app.py
+++ b/server/app.py
@@ -1,9 +1,11 @@
 import os
-from typing import Any, cast
+from typing import Annotated, Any, Literal, cast
 
 import numpy as np
+from pydantic import BaseModel
 import torch
+from transformer_lens.hook_points import HookPoint
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from transformer_lens import HookedTransformer, HookedTransformerConfig
@@ -13,7 +15,7 @@
 import msgpack
 
-from fastapi import FastAPI, Response
+from fastapi import FastAPI, Query, Response
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.middleware.gzip import GZipMiddleware
@@ -21,6 +23,7 @@
 import plotly.graph_objects as go
 
 from lm_saes.analysis.auto_interp import check_description, generate_description
+from lm_saes.circuit.context import apply_sae
 from lm_saes.config import AutoInterpConfig, LanguageModelConfig, SAEConfig
 from lm_saes.database import MongoClient
 from lm_saes.sae import SparseAutoEncoder
@@ -297,31 +300,93 @@ def dictionary_custom_input(dictionary_name: str, input_text: str):
     return Response(content=msgpack.packb(sample), media_type="application/x-msgpack")
 
+
+class SteeringConfig(BaseModel):
+    sae: str
+    feature_index: int
+    steering_type: Literal["times", "add", "set", "ablate"]
+    steering_value: float | None = None
+
+class ModelGenerateRequest(BaseModel):
+    input_text: str
+    max_new_tokens: int = 128
+    top_k: int = 50
+    top_p: float = 0.95
+    return_logits_top_k: int = 5
+    saes: list[str] = []
+    steerings: list[SteeringConfig] = []
+
 @app.post("/model/generate")
-def model_generate(input_text: str, max_new_tokens: int = 128, top_k: int = 50, top_p: float = 0.95, return_logits_top_k: int = 5):
+def model_generate(request: ModelGenerateRequest):
     dictionaries = client.list_dictionaries(dictionary_series=dictionary_series)
     assert len(dictionaries) > 0, "No dictionaries found. Model name cannot be inferred."
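+    # Resolve the model from the available dictionaries, then load every SAE
+    # requested for this generation along with its stored per-feature max activations.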
     model = get_model(dictionaries[0])
+    saes = [(get_sae(name), name) for name in request.saes]
+    max_feature_acts = {
+        name: client.get_max_feature_acts(name, dictionary_series=dictionary_series)
+        for _, name in saes
+    }
+    assert all(steering.sae in request.saes for steering in request.steerings), "Steering SAE not found"
+
+    def generate_steering_hook(steering: SteeringConfig):
+        def steering_hook(tensor: torch.Tensor, hook: HookPoint):
+            assert len(tensor.shape) == 3
+            tensor = tensor.clone()
+            if steering.steering_type == "times":
+                assert steering.steering_value is not None
+                tensor[:, :, steering.feature_index] *= steering.steering_value
+            elif steering.steering_type == "ablate":
+                tensor[:, :, steering.feature_index] = 0
+            elif steering.steering_type == "add":
+                assert steering.steering_value is not None
+                tensor[:, :, steering.feature_index] += steering.steering_value
+            elif steering.steering_type == "set":
+                assert steering.steering_value is not None
+                tensor[:, :, steering.feature_index] = steering.steering_value
+            return tensor
+        sae = get_sae(steering.sae)
+        return f"{sae.cfg.hook_point_out}.sae.hook_feature_acts", steering_hook
+
+    steerings_hooks = [generate_steering_hook(steering) for steering in request.steerings]
+
     with torch.no_grad():
-        input = model.to_tokens(input_text, prepend_bos=False)
-        output = model.generate(input, max_new_tokens=max_new_tokens, top_k=top_k, top_p=top_p)
-        output = output.clone()
-        logits = model.forward(output)
-        logits_topk = [torch.topk(l, return_logits_top_k) for l in logits[0]]
-        result = {
-            "context": [
-                bytearray([byte_decoder[c] for c in t])
-                for t in model.tokenizer.convert_ids_to_tokens(output[0])
-            ],
-            "logits": [l.values.cpu().numpy().tolist() for l in logits_topk],
-            "logits_tokens": [
-                [
-                    bytearray([byte_decoder[c] for c in t])
-                    for t in model.tokenizer.convert_ids_to_tokens(l.indices)
-                ] for l in logits_topk
-            ],
-            "input_mask": [1 for _ in range(len(input[0]))] + [0 for _ in range(len(output[0]) - len(input[0]))],
-        }
+        with apply_sae(model, [sae for sae, _ in saes]):
+            with model.hooks(steerings_hooks):
+                input = model.to_tokens(request.input_text, prepend_bos=False)
+                output = model.generate(input, max_new_tokens=request.max_new_tokens, top_k=request.top_k, top_p=request.top_p)
+                output = output.clone()
+                logits, cache = model.run_with_cache(output, names_filter=[f"{sae.cfg.hook_point_out}.sae.hook_feature_acts" for sae, _ in saes])
+                logits_topk = [torch.topk(l, request.return_logits_top_k) for l in logits[0]]
+                result = {
+                    "context": [
+                        bytearray([byte_decoder[c] for c in t])
+                        for t in model.tokenizer.convert_ids_to_tokens(output[0])
+                    ],
+                    "logits": [l.values.cpu().numpy().tolist() for l in logits_topk],
+                    "logits_tokens": [
+                        [
+                            bytearray([byte_decoder[c] for c in t])
+                            for t in model.tokenizer.convert_ids_to_tokens(l.indices)
+                        ] for l in logits_topk
+                    ],
+                    "input_mask": [1 for _ in range(len(input[0]))] + [0 for _ in range(len(output[0]) - len(input[0]))],
+                    "sae_info": [
+                        {
+                            "name": name,
+                            "feature_acts_indices": [
+                                cache[f"{sae.cfg.hook_point_out}.sae.hook_feature_acts"][0][i].nonzero(as_tuple=True)[0].cpu().numpy().tolist()
+                                for i in range(cache[f"{sae.cfg.hook_point_out}.sae.hook_feature_acts"][0].shape[0])
+                            ],
+                            "feature_acts": [
+                                cache[f"{sae.cfg.hook_point_out}.sae.hook_feature_acts"][0][i][cache[f"{sae.cfg.hook_point_out}.sae.hook_feature_acts"][0][i].nonzero(as_tuple=True)[0]].cpu().numpy().tolist()
+                                for i in range(cache[f"{sae.cfg.hook_point_out}.sae.hook_feature_acts"][0].shape[0])
+                            ],
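+                            # Stored maximum activation for each of the nonzero
+                            # features above, fetched via client.get_max_feature_acts: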
"max_feature_acts": [ + [max_feature_acts[name][j] for j in cache[f"{sae.cfg.hook_point_out}.sae.hook_feature_acts"][0][i].nonzero(as_tuple=True)[0].cpu().numpy().tolist()] + for i in range(cache[f"{sae.cfg.hook_point_out}.sae.hook_feature_acts"][0].shape[0]) + ] + } for sae, name in saes + ], + } return Response(content=msgpack.packb(result), media_type="application/x-msgpack")