pip install git+https://github.com/tomekkorbak/kl-gpt3.git transformers
export OPENAI_API_KEY=sk-YOURKEY
from transformers import AutoModelForCausalLM
from kl_gpt3.kl_gpt3 import evaluate_forward_kl
gpt2 = AutoModelForCausalLM.from_pretrained('gpt2')
kl = evaluate_forward_kl(gpt2, max_tokens=32, num_samples=4)
print(kl)
- handle GPT-3 API timeouts gracefully
- add docstrings
- add tests