diff --git a/notebooks/vis_demo.ipynb b/notebooks/vis_demo.ipynb
index aeeac903..927584a2 100644
--- a/notebooks/vis_demo.ipynb
+++ b/notebooks/vis_demo.ipynb
@@ -9,19 +9,19 @@
     "import torch; torch.set_grad_enabled(False)\n",
     "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
     "\n",
-    "from delphi.eval.utils import tokenize, get_correct_and_top_probs, load_validation_dataset, load_text_from_dataset\n",
+    "from delphi.eval.utils import tokenize, get_next_and_top_k_probs, load_validation_dataset\n",
     "from delphi.eval.vis import vis_sample_prediction_probs\n",
     "\n",
     "model_name = \"roneneldan/TinyStories-1M\"\n",
     "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
     "model = AutoModelForCausalLM.from_pretrained(model_name)\n",
     "ds = load_validation_dataset(\"tinystories-v2-clean\")\n",
-    "ds_txt = load_text_from_dataset(ds)[:100]\n",
+    "ds_txt = ds[\"story\"][:100]\n",
     "ds_tok = [tokenize(tokenizer, txt) for txt in ds_txt]\n",
     "sample_tok = ds_tok[0]\n",
     "\n",
-    "correct_probs, top_3_probs = get_correct_and_top_probs(model, sample_tok, top_k=3)\n",
-    "_, top_5_probs = get_correct_and_top_probs(model, sample_tok, top_k=5)"
+    "correct_probs, top_3_probs = get_next_and_top_k_probs(model, sample_tok, k=3)\n",
+    "_, top_5_probs = get_next_and_top_k_probs(model, sample_tok, k=5)"
    ]
   },
   {
@@ -40,12 +40,12 @@
    "data": {
     "text/html": [
      "\n",
-      \n",
-
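
Note on the rename: the notebook now unpacks `correct_probs, top_3_probs = get_next_and_top_k_probs(model, sample_tok, k=3)` instead of calling `get_correct_and_top_probs(..., top_k=3)`. The real helper lives in `delphi.eval.utils` and is not shown in this diff; the sketch below is only an assumption about the semantics the notebook appears to rely on (probability the model assigns to each actual next token, plus the top-k probabilities at every position), and `next_and_top_k_probs_sketch` is a hypothetical name, not the library function.

```python
# Sketch only: assumed semantics of get_next_and_top_k_probs,
# NOT the delphi.eval.utils implementation.
import torch


def next_and_top_k_probs_sketch(model, input_ids: torch.Tensor, k: int = 3):
    """Return (prob of each actual next token, top-k probs per position).

    Assumes `input_ids` is a 1-D tensor of token ids and `model` is a
    Hugging Face causal LM (logits of shape (1, seq_len, vocab)).
    """
    with torch.inference_mode():
        logits = model(input_ids.unsqueeze(0)).logits[0]  # (seq_len, vocab)
    probs = torch.softmax(logits, dim=-1)
    # Position i predicts token i+1, so drop the last position's distribution
    # and look up the probability of the token that actually follows.
    next_probs = probs[:-1].gather(-1, input_ids[1:].unsqueeze(-1)).squeeze(-1)
    top_k_probs = probs[:-1].topk(k, dim=-1).values
    return next_probs, top_k_probs
```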