diff --git a/posts/2024-forecast.ipynb b/posts/2024-forecast.ipynb index 3faee0c..7dfb679 100644 --- a/posts/2024-forecast.ipynb +++ b/posts/2024-forecast.ipynb @@ -539,11 +539,75 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 114, "id": "a06840b9", "metadata": {}, "outputs": [], - "source": [] + "source": [ + "from transformers import InformerForPrediction, InformerConfig\n", + "\n", + "config = InformerConfig(\n", + " input_dim=context_length,\n", + " prediction_length=prediction_length,\n", + " num_heads=4,\n", + " encoder_layers=2,\n", + " decoder_layers=2,\n", + " use_mask=True,\n", + " forecast=True\n", + ")\n", + "\n", + "informer = InformerForPrediction(config)" + ] + }, + { + "cell_type": "code", + "execution_count": 116, + "id": "5465e218", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "134531" + ] + }, + "execution_count": 116, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "sum(p.numel() for p in informer.parameters())" + ] + }, + { + "cell_type": "code", + "execution_count": 117, + "id": "4f46f6a0", + "metadata": {}, + "outputs": [ + { + "ename": "TypeError", + "evalue": "forward() missing 2 required positional arguments: 'past_time_features' and 'past_observed_mask'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[117], line 7\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m epoch \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(n_epochs):\n\u001b[1;32m 6\u001b[0m optimizer\u001b[38;5;241m.\u001b[39mzero_grad()\n\u001b[0;32m----> 7\u001b[0m y_pred \u001b[38;5;241m=\u001b[39m \u001b[43minformer\u001b[49m\u001b[43m(\u001b[49m\u001b[43mXs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 8\u001b[0m loss \u001b[38;5;241m=\u001b[39m criterion(y_pred, ys)\n\u001b[1;32m 9\u001b[0m loss\u001b[38;5;241m.\u001b[39mbackward()\n", + "File \u001b[0;32m~/miniconda3/lib/python3.9/site-packages/torch/nn/modules/module.py:1501\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1496\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1497\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1498\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1499\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1500\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1502\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1503\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n", + "\u001b[0;31mTypeError\u001b[0m: forward() missing 2 required positional arguments: 'past_time_features' and 'past_observed_mask'" + ] + } + ], + "source": [ + "# training loop\n", + "n_epochs = 100\n", + "\n", + "start_time = time.time()\n", + "for epoch in range(n_epochs):\n", + " optimizer.zero_grad()\n", + " y_pred = informer(Xs, past\n" + ] } ], "metadata": {