From c36f99ff1479f1525b86909897c667efffc14ac3 Mon Sep 17 00:00:00 2001 From: BorjaRequena Date: Wed, 18 Oct 2023 08:40:03 +0200 Subject: [PATCH] zero gradients!! --- nbs/metatutorial.ipynb | 351 ++++++++++++++++++++++++++++++++++------- 1 file changed, 293 insertions(+), 58 deletions(-) diff --git a/nbs/metatutorial.ipynb b/nbs/metatutorial.ipynb index 10ec1ab..13f545c 100644 --- a/nbs/metatutorial.ipynb +++ b/nbs/metatutorial.ipynb @@ -102,7 +102,11 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "vscode": { + "languageId": "python" + } + }, "outputs": [], "source": [ "max_num = 1_000_000\n", @@ -119,7 +123,11 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "vscode": { + "languageId": "python" + } + }, "outputs": [ { "name": "stdout", @@ -171,7 +179,11 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "vscode": { + "languageId": "python" + } + }, "outputs": [ { "data": { @@ -200,7 +212,11 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "vscode": { + "languageId": "python" + } + }, "outputs": [], "source": [ "class Tokenizer:\n", @@ -220,7 +236,11 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "vscode": { + "languageId": "python" + } + }, "outputs": [], "source": [ "tkn = Tokenizer(vocab)" @@ -236,7 +256,11 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "vscode": { + "languageId": "python" + } + }, "outputs": [ { "data": { @@ -273,7 +297,11 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "vscode": { + "languageId": "python" + } + }, "outputs": [ { "data": { @@ -301,7 +329,11 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "vscode": { + "languageId": "python" + } + }, "outputs": [ { "data": { @@ -371,7 +403,11 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "vscode": { + "languageId": "python" + } + }, "outputs": [], "source": [ "#| hide\n", @@ -395,7 +431,11 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "vscode": { + "languageId": "python" + } + }, "outputs": [ { "data": { @@ -430,7 +470,11 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "vscode": { + "languageId": "python" + } + }, "outputs": [], "source": [ "val_pct = 0.1\n", @@ -442,7 +486,11 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "vscode": { + "languageId": "python" + } + }, "outputs": [ { "data": { @@ -469,7 +517,11 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "vscode": { + "languageId": "python" + } + }, "outputs": [], "source": [ "def get_batch(data, batch_size, seq_len):\n", @@ -482,7 +534,11 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "vscode": { + "languageId": "python" + } + }, "outputs": [], "source": [ "batch_size = 64\n", @@ -493,7 +549,11 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "vscode": { + "languageId": "python" + } + }, "outputs": [ { "data": { @@ -513,7 +573,11 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "vscode": { + "languageId": "python" + } + }, "outputs": [ { "data": { @@ -543,7 +607,11 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "vscode": { + "languageId": "python" + } + }, "outputs": [], "source": [ "class BigramLanguageModel(nn.Module):\n", @@ -570,7 +638,11 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "vscode": { + "languageId": "python" + } + }, "outputs": [], "source": [ "bigram_model = BigramLanguageModel(vocab_size).to(device)" @@ -579,7 +651,11 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "vscode": { + "languageId": "python" + } + }, "outputs": [ { "data": { @@ -610,7 +686,11 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "vscode": { + "languageId": "python" + } + }, "outputs": [ { "data": { @@ -642,7 +722,11 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "vscode": { + "languageId": "python" + } + }, "outputs": [], "source": [ "def cross_entropy_loss(logits, targets):\n", @@ -662,7 +746,11 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "vscode": { + "languageId": "python" + } + }, "outputs": [], "source": [ "optimizer = torch.optim.AdamW(bigram_model.parameters(), lr=1e-3)" @@ -678,7 +766,11 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "vscode": { + "languageId": "python" + } + }, "outputs": [ { "name": "stdout", @@ -695,6 +787,8 @@ "\n", "for _ in range(train_steps):\n", " xb, yb = get_batch(data_train, batch_size, seq_len)\n", + " \n", + " optimizer.zero_grad()\n", " logits = bigram_model(xb)\n", " loss = cross_entropy_loss(logits, yb)\n", " loss.backward()\n", @@ -713,13 +807,18 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "vscode": { + "languageId": "python" + } + }, "outputs": [], "source": [ "def train_model(steps, model, lr, batch_sz, seq_len):\n", " optimizer = torch.optim.AdamW(model.parameters(), lr=lr)\n", " for i in range(steps):\n", " xb, yb = get_batch(data_train, batch_sz, seq_len)\n", + " optimizer.zero_grad()\n", " logits = model(xb)\n", " loss = cross_entropy_loss(logits, yb)\n", " loss.backward()\n", @@ -760,7 +859,11 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "vscode": { + "languageId": "python" + } + }, "outputs": [ { "data": { @@ -790,7 +893,11 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "vscode": { + "languageId": "python" + } + }, "outputs": [ { "data": { @@ -862,7 +969,11 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "vscode": { + "languageId": "python" + } + }, "outputs": [], "source": [ "class AttentionHead(nn.Module):\n", @@ -894,7 +1005,11 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "vscode": { + "languageId": "python" + } + }, "outputs": [], "source": [ "#| hide\n", @@ -904,7 +1019,11 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "vscode": { + "languageId": "python" + } + }, "outputs": [], "source": [ "batch_size, seq_len = 1, 8 \n", @@ -914,7 +1033,11 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "vscode": { + "languageId": "python" + } + }, "outputs": [ { "name": "stdout", @@ -941,7 +1064,11 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "vscode": { + "languageId": "python" + } + }, "outputs": [], "source": [ "class CausalAttentionHead(nn.Module):\n", @@ -979,7 +1106,11 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "vscode": { + "languageId": "python" + } + }, "outputs": [], "source": [ "class GPT(nn.Module):\n", @@ -1011,7 +1142,11 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "vscode": { + "languageId": "python" + } + }, "outputs": [], "source": [ "torch.manual_seed(7)\n", @@ -1023,7 +1158,11 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "vscode": { + "languageId": "python" + } + }, "outputs": [ { "name": "stdout", @@ -1042,7 +1181,11 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "vscode": { + "languageId": "python" + } + }, "outputs": [ { "data": { @@ -1078,7 +1221,11 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "vscode": { + "languageId": "python" + } + }, "outputs": [], "source": [ "class GPT(nn.Module):\n", @@ -1113,7 +1260,11 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "vscode": { + "languageId": "python" + } + }, "outputs": [], "source": [ "torch.manual_seed(7)\n", @@ -1125,7 +1276,11 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "vscode": { + "languageId": "python" + } + }, "outputs": [ { "name": "stdout", @@ -1147,7 +1302,11 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "vscode": { + "languageId": "python" + } + }, "outputs": [ { "data": { @@ -1180,7 +1339,11 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "vscode": { + "languageId": "python" + } + }, "outputs": [], "source": [ "class MultiHeadAttention(nn.Module):\n", @@ -1212,7 +1375,11 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "vscode": { + "languageId": "python" + } + }, "outputs": [], "source": [ "class GPT(nn.Module):\n", @@ -1247,7 +1414,11 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "vscode": { + "languageId": "python" + } + }, "outputs": [], "source": [ "torch.manual_seed(7)\n", @@ -1259,7 +1430,11 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "vscode": { + "languageId": "python" + } + }, "outputs": [ { "name": "stdout", @@ -1301,7 +1476,11 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "vscode": { + "languageId": "python" + } + }, "outputs": [ { "data": { @@ -1336,7 +1515,11 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "vscode": { + "languageId": "python" + } + }, "outputs": [], "source": [ "class FeedForward(nn.Module):\n", @@ -1369,7 +1552,11 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "vscode": { + "languageId": "python" + } + }, "outputs": [], "source": [ "class DecoderBlock(nn.Module):\n", @@ -1397,7 +1584,11 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "vscode": { + "languageId": "python" + } + }, "outputs": [], "source": [ "class GPT(nn.Module):\n", @@ -1433,7 +1624,11 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "vscode": { + "languageId": "python" + } + }, "outputs": [], "source": [ "torch.manual_seed(7)\n", @@ -1445,7 +1640,11 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "vscode": { + "languageId": "python" + } + }, "outputs": [ { "name": "stdout", @@ -1487,7 +1686,11 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "vscode": { + "languageId": "python" + } + }, "outputs": [ { "data": { @@ -1516,7 +1719,11 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "vscode": { + "languageId": "python" + } + }, "outputs": [ { "data": { @@ -1548,7 +1755,11 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "vscode": { + "languageId": "python" + } + }, "outputs": [], "source": [ "torch.manual_seed(7)\n", @@ -1560,7 +1771,11 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "vscode": { + "languageId": "python" + } + }, "outputs": [ { "name": "stdout", @@ -1627,7 +1842,11 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "vscode": { + "languageId": "python" + } + }, "outputs": [ { "data": { @@ -1656,7 +1875,11 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "vscode": { + "languageId": "python" + } + }, "outputs": [ { "data": { @@ -1677,7 +1900,11 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "vscode": { + "languageId": "python" + } + }, "outputs": [ { "data": { @@ -1705,7 +1932,11 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "vscode": { + "languageId": "python" + } + }, "outputs": [ { "data": { @@ -1733,7 +1964,11 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "vscode": { + "languageId": "python" + } + }, "outputs": [ { "data": {