diff --git a/Fine_Tuning_BERT_for_Spam_Classification.ipynb b/Fine_Tuning_BERT_for_Spam_Classification.ipynb index d094dda..f83c744 100644 --- a/Fine_Tuning_BERT_for_Spam_Classification.ipynb +++ b/Fine_Tuning_BERT_for_Spam_Classification.ipynb @@ -1,2038 +1,2120 @@ { - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "name": "Fine-Tuning BERT for Spam Classification.ipynb", - "provenance": [], - "authorship_tag": "ABX9TyOt9x7x5Cm/ENCEI4+c+LvL", - "include_colab_link": true + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "view-in-github" + }, + "source": [ + "\"Open" + ] }, - "kernelspec": { - "name": "python3", - "display_name": "Python 3" + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "OFOTiqrtNvyy" + }, + "source": [ + "# Install Transformers Library" + ] }, - "accelerator": "GPU", - "widgets": { - "application/vnd.jupyter.widget-state+json": { - "983fea7c2dc74dfaba7aa60147af85d1": { - "model_module": "@jupyter-widgets/controls", - "model_name": "HBoxModel", - "state": { - "_view_name": "HBoxView", - "_dom_classes": [], - "_model_name": "HBoxModel", - "_view_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_view_count": null, - "_view_module_version": "1.5.0", - "box_style": "", - "layout": "IPY_MODEL_ccf5f7e5cc10493ca9c44b14fdec31dc", - "_model_module": "@jupyter-widgets/controls", - "children": [ - "IPY_MODEL_59bae99ad63d4a3a8b8d622d95f7ad07", - "IPY_MODEL_689e66a8dff249449b5f0f5bbfffa037" + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "1hkhc10wNrGt", + "vscode": { + "languageId": "python" + } + }, + "outputs": [], + "source": [ + "!pip install transformers" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "x4giRzM7NtHJ", + "vscode": { + "languageId": "python" + } + }, + "outputs": [], + "source": [ + "import numpy as np\n", + "import pandas as pd\n", + "import torch\n", + "import torch.nn as nn\n", + "from sklearn.model_selection import train_test_split\n", + "from sklearn.metrics import classification_report\n", + "import transformers\n", + "from transformers import AutoModel, BertTokenizerFast\n", + "\n", + "# specify GPU\n", + "device = torch.device(\"cuda\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "kKd-Tj3hOMsZ" + }, + "source": [ + "# Load Dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 204 + }, + "colab_type": "code", + "id": "cwJrQFQgN_BE", + "outputId": "854f0b55-e330-4806-cc32-79643e6bd721", + "vscode": { + "languageId": "python" + } + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
labeltext
00Go until jurong point, crazy.. Available only ...
10Ok lar... Joking wif u oni...
21Free entry in 2 a wkly comp to win FA Cup fina...
30U dun say so early hor... U c already then say...
40Nah I don't think he goes to usf, he lives aro...
\n", + "
" + ], + "text/plain": [ + " label text\n", + "0 0 Go until jurong point, crazy.. Available only ...\n", + "1 0 Ok lar... Joking wif u oni...\n", + "2 1 Free entry in 2 a wkly comp to win FA Cup fina...\n", + "3 0 U dun say so early hor... U c already then say...\n", + "4 0 Nah I don't think he goes to usf, he lives aro..." ] - } + }, + "execution_count": 3, + "metadata": { + "tags": [] + }, + "output_type": "execute_result" + } + ], + "source": [ + "df = pd.read_csv(\"spamdata_v2.csv\")\n", + "df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 34 }, - "ccf5f7e5cc10493ca9c44b14fdec31dc": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "state": { - "_view_name": "LayoutView", - "grid_template_rows": null, - "right": null, - "justify_content": null, - "_view_module": "@jupyter-widgets/base", - "overflow": null, - "_model_module_version": "1.2.0", - "_view_count": null, - "flex_flow": null, - "width": null, - "min_width": null, - "border": null, - "align_items": null, - "bottom": null, - "_model_module": "@jupyter-widgets/base", - "top": null, - "grid_column": null, - "overflow_y": null, - "overflow_x": null, - "grid_auto_flow": null, - "grid_area": null, - "grid_template_columns": null, - "flex": null, - "_model_name": "LayoutModel", - "justify_items": null, - "grid_row": null, - "max_height": null, - "align_content": null, - "visibility": null, - "align_self": null, - "height": null, - "min_height": null, - "padding": null, - "grid_auto_rows": null, - "grid_gap": null, - "max_width": null, - "order": null, - "_view_module_version": "1.2.0", - "grid_template_areas": null, - "object_position": null, - "object_fit": null, - "grid_auto_columns": null, - "margin": null, - "display": null, - "left": null - } + "colab_type": "code", + "id": "fzPPOrVQWiW5", + "outputId": "e8555c2b-a50d-4809-833f-adf3ac349a1b", + "vscode": { + "languageId": "python" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "(5572, 2)" + ] + }, + "execution_count": 28, + "metadata": { + "tags": [] + }, + "output_type": "execute_result" + } + ], + "source": [ + "df.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 68 }, - "59bae99ad63d4a3a8b8d622d95f7ad07": { - "model_module": "@jupyter-widgets/controls", - "model_name": "FloatProgressModel", - "state": { - "_view_name": "ProgressView", - "style": "IPY_MODEL_49dd79a9a65044ba8345deb250ce4b24", - "_dom_classes": [], - "description": "Downloading: 100%", - "_model_name": "FloatProgressModel", - "bar_style": "success", - "max": 433, - "_view_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "value": 433, - "_view_count": null, - "_view_module_version": "1.5.0", - "orientation": "horizontal", - "min": 0, - "description_tooltip": null, - "_model_module": "@jupyter-widgets/controls", - "layout": "IPY_MODEL_47862cd626cf46619a5cc505fde02276" - } + "colab_type": "code", + "id": "676DPU1BOPdp", + "outputId": "075808af-7b2e-4f0d-e888-e06e2e8abbf9", + "vscode": { + "languageId": "python" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "0 0.865937\n", + "1 0.134063\n", + "Name: label, dtype: float64" + ] + }, + "execution_count": 4, + "metadata": { + "tags": [] + }, + "output_type": "execute_result" + } + ], + "source": [ + "# check class distribution\n", + "df['label'].value_counts(normalize = True)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "MKfWnApvOoE7" + }, + "source": [ + "# Split train dataset into train, validation and test sets" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "mfhSPF5jOWb7", + "vscode": { + "languageId": "python" + } + }, + "outputs": [], + "source": [ + "train_text, temp_text, train_labels, temp_labels = train_test_split(df['text'], df['label'], \n", + " random_state=2018, \n", + " test_size=0.3, \n", + " stratify=df['label'])\n", + "\n", + "# we will use temp_text and temp_labels to create validation and test set\n", + "val_text, test_text, val_labels, test_labels = train_test_split(temp_text, temp_labels, \n", + " random_state=2018, \n", + " test_size=0.5, \n", + " stratify=temp_labels)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "n7hsdLoCO7uB" + }, + "source": [ + "# Import BERT Model and BERT Tokenizer" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 164, + "referenced_widgets": [ + "983fea7c2dc74dfaba7aa60147af85d1", + "ccf5f7e5cc10493ca9c44b14fdec31dc", + "59bae99ad63d4a3a8b8d622d95f7ad07", + "689e66a8dff249449b5f0f5bbfffa037", + "49dd79a9a65044ba8345deb250ce4b24", + "47862cd626cf46619a5cc505fde02276", + "cb5f7a2a5bcb47649703cc633f2fb685", + "6d2355752eb74f348596a380d2347b73", + "0e580433bec2453da54c0ce9ee027401", + "bdfd7634b8bf42aa8794ada8d9e47173", + "73fc7587f1bb49df8a0fc87ecfac7f3c", + "4da1c15300b2468ab1a7e2df800ed39b", + "645d520e8a1c4f1fa202c6c68c5ce6af", + "3bb6b624b4ce4be788c38cb8d1936177", + "ed9f97f9d12a49aa939b7595ad3cb27c", + "88214abee8b9462f86369072d858ae9f", + "b4bef5a685954e238b52c43eefe4c9e5", + "94264f36ceb64d3881fed952bf579072", + "8cf06dad410440f78c50e3527e858905", + "edf0e4c1ae214a66abb717b20e3bffad", + "4d56a47453914de3be15d7a515e5b210", + "6831c2b733d74b31a41cb6eb971a25d7", + "58cd585c531444a5b8613c2d85bab022", + "b6423c858927455e8cbc5a953273466a" + ] }, - "689e66a8dff249449b5f0f5bbfffa037": { - "model_module": "@jupyter-widgets/controls", - "model_name": "HTMLModel", - "state": { - "_view_name": "HTMLView", - "style": "IPY_MODEL_cb5f7a2a5bcb47649703cc633f2fb685", - "_dom_classes": [], - "description": "", - "_model_name": "HTMLModel", - "placeholder": "​", - "_view_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "value": " 433/433 [00:00<00:00, 1.98kB/s]", - "_view_count": null, - "_view_module_version": "1.5.0", - "description_tooltip": null, - "_model_module": "@jupyter-widgets/controls", - "layout": "IPY_MODEL_6d2355752eb74f348596a380d2347b73" - } + "colab_type": "code", + "id": "S1kY3gZjO2RE", + "outputId": "4194574c-05d6-4d1d-c4c0-8dd89913ff79", + "vscode": { + "languageId": "python" + } + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "983fea7c2dc74dfaba7aa60147af85d1", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=433.0, style=ProgressStyle(description_…" + ] + }, + "metadata": { + "tags": [] + }, + "output_type": "display_data" }, - "49dd79a9a65044ba8345deb250ce4b24": { - "model_module": "@jupyter-widgets/controls", - "model_name": "ProgressStyleModel", - "state": { - "_view_name": "StyleView", - "_model_name": "ProgressStyleModel", - "description_width": "initial", - "_view_module": "@jupyter-widgets/base", - "_model_module_version": "1.5.0", - "_view_count": null, - "_view_module_version": "1.2.0", - "bar_color": null, - "_model_module": "@jupyter-widgets/controls" - } + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] }, - "47862cd626cf46619a5cc505fde02276": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "state": { - "_view_name": "LayoutView", - "grid_template_rows": null, - "right": null, - "justify_content": null, - "_view_module": "@jupyter-widgets/base", - "overflow": null, - "_model_module_version": "1.2.0", - "_view_count": null, - "flex_flow": null, - "width": null, - "min_width": null, - "border": null, - "align_items": null, - "bottom": null, - "_model_module": "@jupyter-widgets/base", - "top": null, - "grid_column": null, - "overflow_y": null, - "overflow_x": null, - "grid_auto_flow": null, - "grid_area": null, - "grid_template_columns": null, - "flex": null, - "_model_name": "LayoutModel", - "justify_items": null, - "grid_row": null, - "max_height": null, - "align_content": null, - "visibility": null, - "align_self": null, - "height": null, - "min_height": null, - "padding": null, - "grid_auto_rows": null, - "grid_gap": null, - "max_width": null, - "order": null, - "_view_module_version": "1.2.0", - "grid_template_areas": null, - "object_position": null, - "object_fit": null, - "grid_auto_columns": null, - "margin": null, - "display": null, - "left": null - } + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "0e580433bec2453da54c0ce9ee027401", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=440473133.0, style=ProgressStyle(descri…" + ] + }, + "metadata": { + "tags": [] + }, + "output_type": "display_data" }, - "cb5f7a2a5bcb47649703cc633f2fb685": { - "model_module": "@jupyter-widgets/controls", - "model_name": "DescriptionStyleModel", - "state": { - "_view_name": "StyleView", - "_model_name": "DescriptionStyleModel", - "description_width": "", - "_view_module": "@jupyter-widgets/base", - "_model_module_version": "1.5.0", - "_view_count": null, - "_view_module_version": "1.2.0", - "_model_module": "@jupyter-widgets/controls" - } + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] }, - "6d2355752eb74f348596a380d2347b73": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "state": { - "_view_name": "LayoutView", - "grid_template_rows": null, - "right": null, - "justify_content": null, - "_view_module": "@jupyter-widgets/base", - "overflow": null, - "_model_module_version": "1.2.0", - "_view_count": null, - "flex_flow": null, - "width": null, - "min_width": null, - "border": null, - "align_items": null, - "bottom": null, - "_model_module": "@jupyter-widgets/base", - "top": null, - "grid_column": null, - "overflow_y": null, - "overflow_x": null, - "grid_auto_flow": null, - "grid_area": null, - "grid_template_columns": null, - "flex": null, - "_model_name": "LayoutModel", - "justify_items": null, - "grid_row": null, - "max_height": null, - "align_content": null, - "visibility": null, - "align_self": null, - "height": null, - "min_height": null, - "padding": null, - "grid_auto_rows": null, - "grid_gap": null, - "max_width": null, - "order": null, - "_view_module_version": "1.2.0", - "grid_template_areas": null, - "object_position": null, - "object_fit": null, - "grid_auto_columns": null, - "margin": null, - "display": null, - "left": null - } - }, - "0e580433bec2453da54c0ce9ee027401": { - "model_module": "@jupyter-widgets/controls", - "model_name": "HBoxModel", - "state": { - "_view_name": "HBoxView", - "_dom_classes": [], - "_model_name": "HBoxModel", - "_view_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_view_count": null, - "_view_module_version": "1.5.0", - "box_style": "", - "layout": "IPY_MODEL_bdfd7634b8bf42aa8794ada8d9e47173", - "_model_module": "@jupyter-widgets/controls", - "children": [ - "IPY_MODEL_73fc7587f1bb49df8a0fc87ecfac7f3c", - "IPY_MODEL_4da1c15300b2468ab1a7e2df800ed39b" + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "b4bef5a685954e238b52c43eefe4c9e5", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=231508.0, style=ProgressStyle(descripti…" ] - } - }, - "bdfd7634b8bf42aa8794ada8d9e47173": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "state": { - "_view_name": "LayoutView", - "grid_template_rows": null, - "right": null, - "justify_content": null, - "_view_module": "@jupyter-widgets/base", - "overflow": null, - "_model_module_version": "1.2.0", - "_view_count": null, - "flex_flow": null, - "width": null, - "min_width": null, - "border": null, - "align_items": null, - "bottom": null, - "_model_module": "@jupyter-widgets/base", - "top": null, - "grid_column": null, - "overflow_y": null, - "overflow_x": null, - "grid_auto_flow": null, - "grid_area": null, - "grid_template_columns": null, - "flex": null, - "_model_name": "LayoutModel", - "justify_items": null, - "grid_row": null, - "max_height": null, - "align_content": null, - "visibility": null, - "align_self": null, - "height": null, - "min_height": null, - "padding": null, - "grid_auto_rows": null, - "grid_gap": null, - "max_width": null, - "order": null, - "_view_module_version": "1.2.0", - "grid_template_areas": null, - "object_position": null, - "object_fit": null, - "grid_auto_columns": null, - "margin": null, - "display": null, - "left": null - } + }, + "metadata": { + "tags": [] + }, + "output_type": "display_data" }, - "73fc7587f1bb49df8a0fc87ecfac7f3c": { - "model_module": "@jupyter-widgets/controls", - "model_name": "FloatProgressModel", - "state": { - "_view_name": "ProgressView", - "style": "IPY_MODEL_645d520e8a1c4f1fa202c6c68c5ce6af", - "_dom_classes": [], - "description": "Downloading: 100%", - "_model_name": "FloatProgressModel", - "bar_style": "success", - "max": 440473133, - "_view_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "value": 440473133, - "_view_count": null, - "_view_module_version": "1.5.0", - "orientation": "horizontal", - "min": 0, - "description_tooltip": null, - "_model_module": "@jupyter-widgets/controls", - "layout": "IPY_MODEL_3bb6b624b4ce4be788c38cb8d1936177" - } + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "# import BERT-base pretrained model\n", + "bert = AutoModel.from_pretrained('bert-base-uncased')\n", + "\n", + "# Load the BERT tokenizer\n", + "tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "_zOKeOMeO-DT", + "vscode": { + "languageId": "python" + } + }, + "outputs": [], + "source": [ + "# sample data\n", + "text = [\"this is a bert model tutorial\", \"we will fine-tune a bert model\"]\n", + "\n", + "# encode text\n", + "sent_id = tokenizer.batch_encode_plus(text, padding=True, return_token_type_ids=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 54 }, - "4da1c15300b2468ab1a7e2df800ed39b": { - "model_module": "@jupyter-widgets/controls", - "model_name": "HTMLModel", - "state": { - "_view_name": "HTMLView", - "style": "IPY_MODEL_ed9f97f9d12a49aa939b7595ad3cb27c", - "_dom_classes": [], - "description": "", - "_model_name": "HTMLModel", - "placeholder": "​", - "_view_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "value": " 440M/440M [00:11<00:00, 37.5MB/s]", - "_view_count": null, - "_view_module_version": "1.5.0", - "description_tooltip": null, - "_model_module": "@jupyter-widgets/controls", - "layout": "IPY_MODEL_88214abee8b9462f86369072d858ae9f" - } + "colab_type": "code", + "id": "oAH73n39PHLw", + "outputId": "17b76300-71f2-464c-90b0-a5907f4f675a", + "vscode": { + "languageId": "python" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'input_ids': [[101, 2023, 2003, 1037, 14324, 2944, 14924, 4818, 102, 0], [101, 2057, 2097, 2986, 1011, 8694, 1037, 14324, 2944, 102]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 0], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]}\n" + ] + } + ], + "source": [ + "# output\n", + "print(sent_id)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "8wIYaWI_Prg8" + }, + "source": [ + "# Tokenization" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 282 }, - "645d520e8a1c4f1fa202c6c68c5ce6af": { - "model_module": "@jupyter-widgets/controls", - "model_name": "ProgressStyleModel", - "state": { - "_view_name": "StyleView", - "_model_name": "ProgressStyleModel", - "description_width": "initial", - "_view_module": "@jupyter-widgets/base", - "_model_module_version": "1.5.0", - "_view_count": null, - "_view_module_version": "1.2.0", - "bar_color": null, - "_model_module": "@jupyter-widgets/controls" - } + "colab_type": "code", + "id": "yKwbpeN_PMiu", + "outputId": "9f843240-6cf4-46c9-80b7-dc9ab4e03602", + "vscode": { + "languageId": "python" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 9, + "metadata": { + "tags": [] + }, + "output_type": "execute_result" }, - "3bb6b624b4ce4be788c38cb8d1936177": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "state": { - "_view_name": "LayoutView", - "grid_template_rows": null, - "right": null, - "justify_content": null, - "_view_module": "@jupyter-widgets/base", - "overflow": null, - "_model_module_version": "1.2.0", - "_view_count": null, - "flex_flow": null, - "width": null, - "min_width": null, - "border": null, - "align_items": null, - "bottom": null, - "_model_module": "@jupyter-widgets/base", - "top": null, - "grid_column": null, - "overflow_y": null, - "overflow_x": null, - "grid_auto_flow": null, - "grid_area": null, - "grid_template_columns": null, - "flex": null, - "_model_name": "LayoutModel", - "justify_items": null, - "grid_row": null, - "max_height": null, - "align_content": null, - "visibility": null, - "align_self": null, - "height": null, - "min_height": null, - "padding": null, - "grid_auto_rows": null, - "grid_gap": null, - "max_width": null, - "order": null, - "_view_module_version": "1.2.0", - "grid_template_areas": null, - "object_position": null, - "object_fit": null, - "grid_auto_columns": null, - "margin": null, - "display": null, - "left": null - } - }, - "ed9f97f9d12a49aa939b7595ad3cb27c": { - "model_module": "@jupyter-widgets/controls", - "model_name": "DescriptionStyleModel", - "state": { - "_view_name": "StyleView", - "_model_name": "DescriptionStyleModel", - "description_width": "", - "_view_module": "@jupyter-widgets/base", - "_model_module_version": "1.5.0", - "_view_count": null, - "_view_module_version": "1.2.0", - "_model_module": "@jupyter-widgets/controls" - } - }, - "88214abee8b9462f86369072d858ae9f": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "state": { - "_view_name": "LayoutView", - "grid_template_rows": null, - "right": null, - "justify_content": null, - "_view_module": "@jupyter-widgets/base", - "overflow": null, - "_model_module_version": "1.2.0", - "_view_count": null, - "flex_flow": null, - "width": null, - "min_width": null, - "border": null, - "align_items": null, - "bottom": null, - "_model_module": "@jupyter-widgets/base", - "top": null, - "grid_column": null, - "overflow_y": null, - "overflow_x": null, - "grid_auto_flow": null, - "grid_area": null, - "grid_template_columns": null, - "flex": null, - "_model_name": "LayoutModel", - "justify_items": null, - "grid_row": null, - "max_height": null, - "align_content": null, - "visibility": null, - "align_self": null, - "height": null, - "min_height": null, - "padding": null, - "grid_auto_rows": null, - "grid_gap": null, - "max_width": null, - "order": null, - "_view_module_version": "1.2.0", - "grid_template_areas": null, - "object_position": null, - "object_fit": null, - "grid_auto_columns": null, - "margin": null, - "display": null, - "left": null - } - }, - "b4bef5a685954e238b52c43eefe4c9e5": { - "model_module": "@jupyter-widgets/controls", - "model_name": "HBoxModel", - "state": { - "_view_name": "HBoxView", - "_dom_classes": [], - "_model_name": "HBoxModel", - "_view_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_view_count": null, - "_view_module_version": "1.5.0", - "box_style": "", - "layout": "IPY_MODEL_94264f36ceb64d3881fed952bf579072", - "_model_module": "@jupyter-widgets/controls", - "children": [ - "IPY_MODEL_8cf06dad410440f78c50e3527e858905", - "IPY_MODEL_edf0e4c1ae214a66abb717b20e3bffad" + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX4AAAD4CAYAAADrRI2NAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAUhklEQVR4nO3df5Dcd13H8efbxhboYdIfzE0niV7QiFMbleamrYMyd8aBNEVSFZl2OpBgnYxji8XWoUFG66jMBBURRsSJpkPQyhURprEtQgw9Gf5IpamlSVtKryVIbkIqtASPVjH69o/9nG7Pu9z+SHZv83k+Zm7uu5/vZ3df++32td/97vc2kZlIkurxXf0OIEnqLYtfkipj8UtSZSx+SaqMxS9JlVnW7wAnc+GFF+bIyEjb1/v2t7/Nueeee+oDnUaDlnnQ8oKZe2XQMg9aXlg884EDB76emS9bcEJmLtmf9evXZyfuu+++jq7XT4OWedDyZpq5VwYt86DlzVw8M/BAnqRbPdQjSZWx+CWpMha/JFXG4pekylj8klQZi1+SKmPxS1JlLH5JqozFL0mVWdJf2dArI9vvaWne4R1XneYkknT6uccvSZVZtPgj4vaIeDoiDjWN/UFEfDEiHo6IT0TEiqZ174iIqYh4PCJe2zS+sYxNRcT2U/9QJEmtaGWP/0PAxjlje4FLMvNHgC8B7wCIiIuBa4AfLtf504g4KyLOAj4AXAlcDFxb5kqSemzR4s/MzwLPzBn7dGaeKBf3A6vK8mZgIjP/IzO/DEwBl5Wfqcx8KjO/A0yUuZKkHovGN3guMiliBLg7My+ZZ93fAXdm5l9FxJ8A+zPzr8q6XcAny9SNmflLZfxNwOWZeeM8t7cN2AYwPDy8fmJiou0HNTMzw9DQUMvzD04fb2neupXL287SqnYz99ug5QUz98qgZR60vLB45vHx8QOZObrQ+q7O6omIdwIngDu6uZ1mmbkT2AkwOjqaY2Njbd/G5OQk7Vxva6tn9VzXfpZWtZu53wYtL5i5VwYt86Dlhe4zd1z8EbEVeB2wIf/vbcM0sLpp2qoyxknGJUk91NHpnBGxEXg78PrMfK5p1R7gmog4JyLWAGuBfwI+D6yNiDURcTaND4D3dBddktSJRff4I+IjwBhwYUQcAW6jcRbPOcDeiIDGcf1fzsxHIuKjwKM0DgHdkJn/VW7nRuBTwFnA7Zn5yGl4PJKkRSxa/Jl57TzDu04y/13Au+YZvxe4t610kqRTzr/claTKWPySVBmLX5IqY/FLUmUsfkmqjMUvSZWx+CWpMha/JFXG4pekylj8klQZi1+SKmPxS1JlLH5JqozFL0mVsfglqTIWvyRVxuKXpMpY/JJUGYtfkipj8UtSZSx+SaqMxS9JlbH4JakyFr8kVcbil6TKLFr8EXF7RDwdEYeaxs6PiL0R8UT5fV4Zj4h4f0RMRcTDEXFp03W2lPlPRMSW0/NwJEmLaWWP/0PAxjlj24F9mbkW2FcuA1wJrC0/24APQuOFArgNuBy4DLht9sVCktRbixZ/Zn4WeGbO8GZgd1neDVzdNP7hbNgPrIiIi4DXAnsz85nMfBbYy/9/MZEk9UBk5uKTIkaAuzPzknL5m5m5oiwH8GxmroiIu4Edmfm5sm4fcCswBrwoM3+vjP8m8Hxm/uE897WNxrsFhoeH109MTLT9oGZmZhgaGmp5/sHp4y3NW7dyedtZWtVu5n4btLxg5l4ZtMyDlhcWzzw+Pn4gM0cXWr+s2wCZmRGx+KtH67e3E9gJMDo6mmNjY23fxuTkJO1cb+v2e1qad/i69rO0qt3M/TZoecHMvTJomQctL3SfudOzeo6VQziU30+X8WlgddO8VWVsoXFJUo91Wvx7gNkzc7YAdzWNv7mc3XMFcDwzjwKfAl4TEeeVD3VfU8YkST226KGeiPgIjWP0F0bEERpn5+wAPhoR1wNfAd5Ypt8LbAKmgOeAtwBk5jMR8bvA58u838nMuR8YS5J6YNHiz8xrF1i1YZ65CdywwO3cDtzeVjpJ0innX+5KUmUsfkmqjMUvSZWx+CWpMha/JFXG4pekylj8klQZi1+SKmPxS1JlLH5JqozFL0mVsfglqTIWvyRVxuKXpMpY/JJUGYtfkipj8UtSZSx+SaqMxS9JlbH4JakyFr8kVcbil6TKWPySVBmLX5Iq01XxR8SvRcQjEXEoIj4SES+KiDURcX9ETEXEnRFxdpl7Trk8VdaPnIoHIElqT8fFHxErgV8FRjPzEuAs4Brg3cB7M/MHgGeB68tVrgeeLePvLfMkST3W7aGeZcCLI2IZ8BLgKPBTwMfK+t3A1WV5c7lMWb8hIqLL+5cktSkys/MrR9wEvAt4Hvg0cBOwv+zVExGrgU9m5iURcQjYmJlHyrongcsz8+tzbnMbsA1geHh4/cTERNu5ZmZmGBoaann+wenjLc1bt3J521la1W7mfhu0vGDmXhm0zIOWFxbPPD4+fiAzRxdav6zTO46I82jsxa8Bvgn8DbCx09ublZk7gZ0Ao6OjOTY21vZtTE5O0s71tm6/p6V5h69rP0ur2s3cb4OWF8zcK4OWedDyQveZuznU89PAlzPzXzPzP4GPA68CVpRDPwCrgOmyPA2sBijrlwPf6OL+JUkd6Kb4/wW4IiJeUo7VbwAeBe4D3lDmbAHuKst7ymXK+s9kN8eZJEkd6bj4M/N+Gh/SPggcLLe1E7gVuDkipoALgF3lKruAC8r4zcD2LnJLkjrU8TF+gMy8DbhtzvBTwGXzzP134Be6ub92jbR47F6SauJf7kpSZSx+SaqMxS9JlbH4JakyFr8kVcbil6TKWPySVBmLX5IqY/FLUmUsfkmqjMUvSZWx+CWpMha/JFXG4pekylj8klQZi1+SKmPxS1JlLH5JqozFL0mVsfglqTIWvyRVxuKXpMpY/JJUGYtfkipj8UtSZboq/ohYEREfi4gvRsRjEfHjEXF+ROyNiCfK7/PK3IiI90fEVEQ8HBGXnpqHIElqR7d7/O8D/j4zfwj4UeAxYDuwLzPXAvvKZYArgbXlZxvwwS7vW5LUgY6LPyKWA68GdgFk5ncy85vAZmB3mbYbuLosbwY+nA37gRURcVHHySVJHYnM7OyKET8G7AQepbG3fwC4CZjOzBVlTgDPZuaKiLgb2JGZnyvr9gG3ZuYDc253G413BAwPD6+fmJhoO9vMzAxDQ0McnD7e0WNbyLqVy0/p7TWbzTwoBi0vmLlXBi3zoOWFxTOPj48fyMzRhdYv6+K+lwGXAm/NzPsj4n3832EdADIzI6KtV5bM3EnjBYXR0dEcGxtrO9jk5CRjY2Ns3X5P29c9mcPXtZ+lVbOZB8Wg5QUz98qgZR60vNB95m6O8R8BjmTm/eXyx2i8EBybPYRTfj9d1k8Dq5uuv6qMSZJ6qOPiz8yvAV+NiFeUoQ00DvvsAbaUsS3AXWV5D/DmcnbPFcDxzDza6f1LkjrTzaEegLcCd0TE2cBTwFtovJh8NCKuB74CvLHMvRfYBEwBz5W5kqQe66r4M/MhYL4PEDbMMzeBG7q5P0lS9/zLXUmqjMUvSZWx+CWpMt1+uKsujDT9ncEt604s+HcHh3dc1atIkirgHr8kVcbil6TKWPySVBmLX5Iq44e7bRhp8Uvf/DBW0lLmHr8kVcbil6TKWPySVBmLX5IqY/FLUmU8q+c0aPXsH0nqB/f4JakyFr8kVcbil6TKWPySVBmLX5IqY/FLUmUsfkmqjMUvSZWx+CWpMl0Xf0ScFRH/HBF3l8trIuL+iJiKiDsj4uwyfk65PFXWj3R735Kk9p2KPf6bgMeaLr8beG9m/gDwLHB9Gb8eeLaMv7fMkyT1WFfFHxGrgKuAvyiXA/gp4GNlym7g6rK8uVymrN9Q5kuSeqjbPf4/Bt4O/He5fAHwzcw8US4fAVaW5ZXAVwHK+uNlviSphyIzO7tixOuATZn5KxExBvw6sBXYXw7nEBGrgU9m5iURcQjYmJlHyrongcsz8+tzbncbsA1geHh4/cTERNvZZmZmGBoa4uD08Y4eWz8MvxiOPT//unUrl/c2TAtmt/EgMXNvDFrmQcsLi2ceHx8/kJmjC63v5muZXwW8PiI2AS8Cvgd4H7AiIpaVvfpVwHSZPw2sBo5ExDJgOfCNuTeamTuBnQCjo6M5NjbWdrDJyUnGxsbYOkBfj3zLuhO85+D8/zkOXzfW2zAtmN3Gg8TMvTFomQctL3SfueNDPZn5jsxclZkjwDXAZzLzOuA+4A1l2hbgrrK8p1ymrP9Mdvp2Q5LUsdNxHv+twM0RMUXjGP6uMr4LuKCM3wxsPw33LUlaxCn5F7gycxKYLMtPAZfNM+ffgV84FfcnSeqcf7krSZWx+CWpMha/JFXG4pekylj8klQZi1+SKmPxS1JlLH5JqozFL0mVsfglqTIWvyRVxuKXpMpY/JJUGYtfkipj8UtSZSx+SaqMxS9JlbH4JakyFr8kVcbil6TKWPySVBmLX5IqY/FLUmUsfkmqjMUvSZXpuPgjYnVE3BcRj0bEIxFxUxk/PyL2RsQT5fd5ZTwi4v0RMRURD0fEpafqQUiSWresi+ueAG7JzAcj4qXAgYjYC2wF9mXmjojYDmwHbgWuBNaWn8uBD5bfWsTI9ntannt4x1WnMYmkM0HHe/yZeTQzHyzL/wY8BqwENgO7y7TdwNVleTPw4WzYD6yIiIs6Ti5J6khkZvc3EjECfBa4BPiXzFxRxgN4NjNXRMTdwI7M/FxZtw+4NTMfmHNb24BtAMPDw+snJibazjMzM8PQ0BAHp493/qB6bPjFcOz57m9n3crl3d9IC2a38SAxc28MWuZBywuLZx4fHz+QmaMLre/mUA8AETEE/C3wtsz8VqPrGzIzI6KtV5bM3AnsBBgdHc2xsbG2M01OTjI2NsbWNg6R9Nst607wnoNd/+fg8HVj3Ydpwew2HiRm7o1ByzxoeaH7zF2d1RMR302j9O/IzI+X4WOzh3DK76fL+DSwuunqq8qYJKmHujmrJ4BdwGOZ+UdNq/YAW8ryFuCupvE3l7N7rgCOZ+bRTu9fktSZbo4tvAp4E3AwIh4qY78B7AA+GhHXA18B3ljW3QtsAqaA54C3dHHfkqQOdVz85UPaWGD1hnnmJ3BDp/cnSTo1/MtdSaqMxS9JlbH4JakyFr8kVcbil6TKWPySVBmLX5IqY/FLUmUsfkmqjMUvSZWx+CWpMha/JFXG4pekylj8klQZi1+SKmPxS1JlLH5Jqkw3//SilqCR7fe0NO/wjqtOcxJJS5V7/JJUGYtfkipj8UtSZSx+SaqMxS9JlbH4Jakyns5ZKU/7lOrV8+KPiI3A+4CzgL/IzB29zqD+8QVH6r+eHuqJiLOADwBXAhcD10bExb3MIEm16/Ue/2XAVGY+BRARE8Bm4NEe51CLFtpDv2XdCba2uPd+OrX6DgJaz9zqu4127rsVvstRr0Rm9u7OIt4AbMzMXyqX3wRcnpk3Ns3ZBmwrF18BPN7BXV0IfL3LuL02aJkHLS+YuVcGLfOg5YXFM39fZr5soZVL7sPdzNwJ7OzmNiLigcwcPUWRemLQMg9aXjBzrwxa5kHLC91n7vXpnNPA6qbLq8qYJKlHel38nwfWRsSaiDgbuAbY0+MMklS1nh7qycwTEXEj8Ckap3PenpmPnIa76upQUZ8MWuZBywtm7pVByzxoeaHbw+G9/HBXktR/fmWDJFXG4pekypxRxR8RGyPi8YiYiojt/c4zn4hYHRH3RcSjEfFIRNxUxn87IqYj4qHys6nfWZtFxOGIOFiyPVDGzo+IvRHxRPl9Xr9zzoqIVzRty4ci4lsR8baltp0j4vaIeDoiDjWNzbtdo+H95fn9cERcukTy/kFEfLFk+kRErCjjIxHxfNO2/rNe5z1J5gWfBxHxjrKNH4+I1y6hzHc25T0cEQ+V8fa3c2aeET80Pix+Eng5cDbwBeDifueaJ+dFwKVl+aXAl2h8fcVvA7/e73wnyX0YuHDO2O8D28vyduDd/c55kufG14DvW2rbGXg1cClwaLHtCmwCPgkEcAVw/xLJ+xpgWVl+d1PekeZ5S2wbz/s8KP8vfgE4B1hTOuWspZB5zvr3AL/V6XY+k/b4//frIDLzO8Ds10EsKZl5NDMfLMv/BjwGrOxvqo5tBnaX5d3A1X3McjIbgCcz8yv9DjJXZn4WeGbO8ELbdTPw4WzYD6yIiIt6k7RhvryZ+enMPFEu7qfx9zlLxgLbeCGbgYnM/I/M/DIwRaNbeupkmSMigDcCH+n09s+k4l8JfLXp8hGWeKFGxAjwSuD+MnRjebt8+1I6bFIk8OmIOFC+VgNgODOPluWvAcP9ibaoa3jh/yRLeTvDwtt1EJ7jv0jjXcmsNRHxzxHxjxHxk/0KtYD5ngeDsI1/EjiWmU80jbW1nc+k4h8oETEE/C3wtsz8FvBB4PuBHwOO0ngrt5T8RGZeSuObVW+IiFc3r8zGe84ld25w+UPB1wN/U4aW+nZ+gaW6XecTEe8ETgB3lKGjwPdm5iuBm4G/jojv6Ve+OQbqeTDHtbxwR6bt7XwmFf/AfB1ERHw3jdK/IzM/DpCZxzLzvzLzv4E/pw9vL08mM6fL76eBT9DId2z2UEP5/XT/Ei7oSuDBzDwGS387Fwtt1yX7HI+IrcDrgOvKixXlcMk3yvIBGsfLf7BvIZuc5HmwZLcxQEQsA34OuHN2rJPtfCYV/0B8HUQ5PrcLeCwz/6hpvPlY7c8Ch+Zet18i4tyIeOnsMo0P8w7R2L5byrQtwF39SXhSL9g7WsrbuclC23UP8OZyds8VwPGmQ0J9E41/XOntwOsz87mm8ZdF49/gICJeDqwFnupPyhc6yfNgD3BNRJwTEWtoZP6nXuc7iZ8GvpiZR2YHOtrOvf60+jR/Er6JxlkyTwLv7HeeBTL+BI237g8DD5WfTcBfAgfL+B7gon5nbcr8chpnOnwBeGR22wIXAPuAJ4B/AM7vd9Y5uc8FvgEsbxpbUtuZxovSUeA/aRxPvn6h7UrjbJ4PlOf3QWB0ieSdonFcfPb5/Gdl7s+X58tDwIPAzyyhbbzg8wB4Z9nGjwNXLpXMZfxDwC/Pmdv2dvYrGySpMmfSoR5JUgssfkmqjMUvSZWx+CWpMha/JFXG4pekylj8klSZ/wEmDJk6BAzVDQAAAABJRU5ErkJggg==", + "text/plain": [ + "
" ] - } - }, - "94264f36ceb64d3881fed952bf579072": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "state": { - "_view_name": "LayoutView", - "grid_template_rows": null, - "right": null, - "justify_content": null, - "_view_module": "@jupyter-widgets/base", - "overflow": null, - "_model_module_version": "1.2.0", - "_view_count": null, - "flex_flow": null, - "width": null, - "min_width": null, - "border": null, - "align_items": null, - "bottom": null, - "_model_module": "@jupyter-widgets/base", - "top": null, - "grid_column": null, - "overflow_y": null, - "overflow_x": null, - "grid_auto_flow": null, - "grid_area": null, - "grid_template_columns": null, - "flex": null, - "_model_name": "LayoutModel", - "justify_items": null, - "grid_row": null, - "max_height": null, - "align_content": null, - "visibility": null, - "align_self": null, - "height": null, - "min_height": null, - "padding": null, - "grid_auto_rows": null, - "grid_gap": null, - "max_width": null, - "order": null, - "_view_module_version": "1.2.0", - "grid_template_areas": null, - "object_position": null, - "object_fit": null, - "grid_auto_columns": null, - "margin": null, - "display": null, - "left": null - } - }, - "8cf06dad410440f78c50e3527e858905": { - "model_module": "@jupyter-widgets/controls", - "model_name": "FloatProgressModel", - "state": { - "_view_name": "ProgressView", - "style": "IPY_MODEL_4d56a47453914de3be15d7a515e5b210", - "_dom_classes": [], - "description": "Downloading: 100%", - "_model_name": "FloatProgressModel", - "bar_style": "success", - "max": 231508, - "_view_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "value": 231508, - "_view_count": null, - "_view_module_version": "1.5.0", - "orientation": "horizontal", - "min": 0, - "description_tooltip": null, - "_model_module": "@jupyter-widgets/controls", - "layout": "IPY_MODEL_6831c2b733d74b31a41cb6eb971a25d7" - } - }, - "edf0e4c1ae214a66abb717b20e3bffad": { - "model_module": "@jupyter-widgets/controls", - "model_name": "HTMLModel", - "state": { - "_view_name": "HTMLView", - "style": "IPY_MODEL_58cd585c531444a5b8613c2d85bab022", - "_dom_classes": [], - "description": "", - "_model_name": "HTMLModel", - "placeholder": "​", - "_view_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "value": " 232k/232k [00:39<00:00, 5.82kB/s]", - "_view_count": null, - "_view_module_version": "1.5.0", - "description_tooltip": null, - "_model_module": "@jupyter-widgets/controls", - "layout": "IPY_MODEL_b6423c858927455e8cbc5a953273466a" - } - }, - "4d56a47453914de3be15d7a515e5b210": { - "model_module": "@jupyter-widgets/controls", - "model_name": "ProgressStyleModel", - "state": { - "_view_name": "StyleView", - "_model_name": "ProgressStyleModel", - "description_width": "initial", - "_view_module": "@jupyter-widgets/base", - "_model_module_version": "1.5.0", - "_view_count": null, - "_view_module_version": "1.2.0", - "bar_color": null, - "_model_module": "@jupyter-widgets/controls" - } - }, - "6831c2b733d74b31a41cb6eb971a25d7": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "state": { - "_view_name": "LayoutView", - "grid_template_rows": null, - "right": null, - "justify_content": null, - "_view_module": "@jupyter-widgets/base", - "overflow": null, - "_model_module_version": "1.2.0", - "_view_count": null, - "flex_flow": null, - "width": null, - "min_width": null, - "border": null, - "align_items": null, - "bottom": null, - "_model_module": "@jupyter-widgets/base", - "top": null, - "grid_column": null, - "overflow_y": null, - "overflow_x": null, - "grid_auto_flow": null, - "grid_area": null, - "grid_template_columns": null, - "flex": null, - "_model_name": "LayoutModel", - "justify_items": null, - "grid_row": null, - "max_height": null, - "align_content": null, - "visibility": null, - "align_self": null, - "height": null, - "min_height": null, - "padding": null, - "grid_auto_rows": null, - "grid_gap": null, - "max_width": null, - "order": null, - "_view_module_version": "1.2.0", - "grid_template_areas": null, - "object_position": null, - "object_fit": null, - "grid_auto_columns": null, - "margin": null, - "display": null, - "left": null - } - }, - "58cd585c531444a5b8613c2d85bab022": { - "model_module": "@jupyter-widgets/controls", - "model_name": "DescriptionStyleModel", - "state": { - "_view_name": "StyleView", - "_model_name": "DescriptionStyleModel", - "description_width": "", - "_view_module": "@jupyter-widgets/base", - "_model_module_version": "1.5.0", - "_view_count": null, - "_view_module_version": "1.2.0", - "_model_module": "@jupyter-widgets/controls" - } - }, - "b6423c858927455e8cbc5a953273466a": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "state": { - "_view_name": "LayoutView", - "grid_template_rows": null, - "right": null, - "justify_content": null, - "_view_module": "@jupyter-widgets/base", - "overflow": null, - "_model_module_version": "1.2.0", - "_view_count": null, - "flex_flow": null, - "width": null, - "min_width": null, - "border": null, - "align_items": null, - "bottom": null, - "_model_module": "@jupyter-widgets/base", - "top": null, - "grid_column": null, - "overflow_y": null, - "overflow_x": null, - "grid_auto_flow": null, - "grid_area": null, - "grid_template_columns": null, - "flex": null, - "_model_name": "LayoutModel", - "justify_items": null, - "grid_row": null, - "max_height": null, - "align_content": null, - "visibility": null, - "align_self": null, - "height": null, - "min_height": null, - "padding": null, - "grid_auto_rows": null, - "grid_gap": null, - "max_width": null, - "order": null, - "_view_module_version": "1.2.0", - "grid_template_areas": null, - "object_position": null, - "object_fit": null, - "grid_auto_columns": null, - "margin": null, - "display": null, - "left": null - } + }, + "metadata": { + "needs_background": "light", + "tags": [] + }, + "output_type": "display_data" } - } - } - }, - "cells": [ + ], + "source": [ + "# get length of all the messages in the train set\n", + "seq_len = [len(i.split()) for i in train_text]\n", + "\n", + "pd.Series(seq_len).hist(bins = 30)" + ] + }, { - "cell_type": "markdown", + "cell_type": "code", + "execution_count": 10, "metadata": { - "id": "view-in-github", - "colab_type": "text" + "colab": {}, + "colab_type": "code", + "id": "OXcswEIRPvGe", + "vscode": { + "languageId": "python" + } }, + "outputs": [], "source": [ - "\"Open" + "max_seq_len = 25" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "tk5S7DWaP2t6", + "vscode": { + "languageId": "python" + } + }, + "outputs": [], + "source": [ + "# tokenize and encode sequences in the training set\n", + "tokens_train = tokenizer.batch_encode_plus(\n", + " train_text.tolist(),\n", + " max_length = max_seq_len,\n", + " pad_to_max_length=True,\n", + " truncation=True,\n", + " return_token_type_ids=False\n", + ")\n", + "\n", + "# tokenize and encode sequences in the validation set\n", + "tokens_val = tokenizer.batch_encode_plus(\n", + " val_text.tolist(),\n", + " max_length = max_seq_len,\n", + " pad_to_max_length=True,\n", + " truncation=True,\n", + " return_token_type_ids=False\n", + ")\n", + "\n", + "# tokenize and encode sequences in the test set\n", + "tokens_test = tokenizer.batch_encode_plus(\n", + " test_text.tolist(),\n", + " max_length = max_seq_len,\n", + " pad_to_max_length=True,\n", + " truncation=True,\n", + " return_token_type_ids=False\n", + ")" ] }, { "cell_type": "markdown", "metadata": { - "id": "OFOTiqrtNvyy", - "colab_type": "text" + "colab_type": "text", + "id": "Wsm8bkRZQTw9" }, "source": [ - "# Install Transformers Library" + "# Convert Integer Sequences to Tensors" ] }, { "cell_type": "code", + "execution_count": 12, "metadata": { - "id": "1hkhc10wNrGt", + "colab": {}, "colab_type": "code", - "colab": {} + "id": "QR-lXwmzQPd6", + "vscode": { + "languageId": "python" + } }, + "outputs": [], "source": [ - "!pip install transformers" - ], - "execution_count": null, - "outputs": [] + "# for train set\n", + "train_seq = torch.tensor(tokens_train['input_ids'])\n", + "train_mask = torch.tensor(tokens_train['attention_mask'])\n", + "train_y = torch.tensor(train_labels.tolist())\n", + "\n", + "# for validation set\n", + "val_seq = torch.tensor(tokens_val['input_ids'])\n", + "val_mask = torch.tensor(tokens_val['attention_mask'])\n", + "val_y = torch.tensor(val_labels.tolist())\n", + "\n", + "# for test set\n", + "test_seq = torch.tensor(tokens_test['input_ids'])\n", + "test_mask = torch.tensor(tokens_test['attention_mask'])\n", + "test_y = torch.tensor(test_labels.tolist())" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "Ov1cOBlcRLuk" + }, + "source": [ + "# Create DataLoaders" + ] }, { "cell_type": "code", + "execution_count": 13, "metadata": { - "id": "x4giRzM7NtHJ", + "colab": {}, "colab_type": "code", - "colab": {} + "id": "qUy9JKFYQYLp", + "vscode": { + "languageId": "python" + } }, + "outputs": [], "source": [ - "import numpy as np\n", - "import pandas as pd\n", - "import torch\n", - "import torch.nn as nn\n", - "from sklearn.model_selection import train_test_split\n", - "from sklearn.metrics import classification_report\n", - "import transformers\n", - "from transformers import AutoModel, BertTokenizerFast\n", + "from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler\n", "\n", - "# specify GPU\n", - "device = torch.device(\"cuda\")" - ], - "execution_count": 2, - "outputs": [] + "#define a batch size\n", + "batch_size = 32\n", + "\n", + "# wrap tensors\n", + "train_data = TensorDataset(train_seq, train_mask, train_y)\n", + "\n", + "# sampler for sampling the data during training\n", + "train_sampler = RandomSampler(train_data)\n", + "\n", + "# dataLoader for train set\n", + "train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=batch_size)\n", + "\n", + "# wrap tensors\n", + "val_data = TensorDataset(val_seq, val_mask, val_y)\n", + "\n", + "# sampler for sampling the data during training\n", + "val_sampler = SequentialSampler(val_data)\n", + "\n", + "# dataLoader for validation set\n", + "val_dataloader = DataLoader(val_data, sampler = val_sampler, batch_size=batch_size)" + ] }, { "cell_type": "markdown", "metadata": { - "id": "kKd-Tj3hOMsZ", - "colab_type": "text" + "colab_type": "text", + "id": "K2HZc5ZYRV28" }, "source": [ - "# Load Dataset" + "# Freeze BERT Parameters" ] }, { "cell_type": "code", + "execution_count": 14, "metadata": { - "id": "cwJrQFQgN_BE", + "colab": {}, "colab_type": "code", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 204 - }, - "outputId": "854f0b55-e330-4806-cc32-79643e6bd721" + "id": "wHZ0MC00RQA_", + "vscode": { + "languageId": "python" + } }, + "outputs": [], "source": [ - "df = pd.read_csv(\"spamdata_v2.csv\")\n", - "df.head()" - ], - "execution_count": 3, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
labeltext
00Go until jurong point, crazy.. Available only ...
10Ok lar... Joking wif u oni...
21Free entry in 2 a wkly comp to win FA Cup fina...
30U dun say so early hor... U c already then say...
40Nah I don't think he goes to usf, he lives aro...
\n", - "
" - ], - "text/plain": [ - " label text\n", - "0 0 Go until jurong point, crazy.. Available only ...\n", - "1 0 Ok lar... Joking wif u oni...\n", - "2 1 Free entry in 2 a wkly comp to win FA Cup fina...\n", - "3 0 U dun say so early hor... U c already then say...\n", - "4 0 Nah I don't think he goes to usf, he lives aro..." - ] - }, - "metadata": { - "tags": [] - }, - "execution_count": 3 - } + "# freeze all the parameters\n", + "for param in bert.parameters():\n", + " param.requires_grad = False" ] }, { - "cell_type": "code", + "cell_type": "markdown", "metadata": { - "id": "fzPPOrVQWiW5", - "colab_type": "code", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 34 - }, - "outputId": "e8555c2b-a50d-4809-833f-adf3ac349a1b" + "colab_type": "text", + "id": "s7ahGBUWRi3X" }, "source": [ - "df.shape" - ], - "execution_count": 28, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "(5572, 2)" - ] - }, - "metadata": { - "tags": [] - }, - "execution_count": 28 - } + "# Define Model Architecture" ] }, { "cell_type": "code", + "execution_count": 15, "metadata": { - "id": "676DPU1BOPdp", + "colab": {}, "colab_type": "code", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 68 - }, - "outputId": "075808af-7b2e-4f0d-e888-e06e2e8abbf9" + "id": "b3iEtGyYRd0A", + "vscode": { + "languageId": "python" + } }, + "outputs": [], "source": [ - "# check class distribution\n", - "df['label'].value_counts(normalize = True)" - ], - "execution_count": 4, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "0 0.865937\n", - "1 0.134063\n", - "Name: label, dtype: float64" - ] - }, - "metadata": { - "tags": [] - }, - "execution_count": 4 - } + "class BERT_Arch(nn.Module):\n", + "\n", + " def __init__(self, bert):\n", + " \n", + " super(BERT_Arch, self).__init__()\n", + "\n", + " self.bert = bert \n", + " \n", + " # dropout layer\n", + " self.dropout = nn.Dropout(0.1)\n", + " \n", + " # relu activation function\n", + " self.relu = nn.ReLU()\n", + "\n", + " # dense layer 1\n", + " self.fc1 = nn.Linear(768,512)\n", + " \n", + " # dense layer 2 (Output layer)\n", + " self.fc2 = nn.Linear(512,2)\n", + "\n", + " #softmax activation function\n", + " self.softmax = nn.LogSoftmax(dim=1)\n", + "\n", + " #define the forward pass\n", + " def forward(self, sent_id, mask):\n", + "\n", + " #pass the inputs to the model \n", + " _, cls_hs = self.bert(sent_id, attention_mask=mask, return_dict=False)\n", + " \n", + " x = self.fc1(cls_hs)\n", + "\n", + " x = self.relu(x)\n", + "\n", + " x = self.dropout(x)\n", + "\n", + " # output layer\n", + " x = self.fc2(x)\n", + " \n", + " # apply softmax activation\n", + " x = self.softmax(x)\n", + "\n", + " return x" ] }, { - "cell_type": "markdown", + "cell_type": "code", + "execution_count": 16, "metadata": { - "id": "MKfWnApvOoE7", - "colab_type": "text" + "colab": {}, + "colab_type": "code", + "id": "cBAJJVuJRliv", + "vscode": { + "languageId": "python" + } }, + "outputs": [], "source": [ - "# Split train dataset into train, validation and test sets" + "# pass the pre-trained BERT to our define architecture\n", + "model = BERT_Arch(bert)\n", + "\n", + "# push the model to GPU\n", + "model = model.to(device)" ] }, { "cell_type": "code", + "execution_count": 17, "metadata": { - "id": "mfhSPF5jOWb7", + "colab": {}, "colab_type": "code", - "colab": {} + "id": "taXS0IilRn9J", + "vscode": { + "languageId": "python" + } }, + "outputs": [], "source": [ - "train_text, temp_text, train_labels, temp_labels = train_test_split(df['text'], df['label'], \n", - " random_state=2018, \n", - " test_size=0.3, \n", - " stratify=df['label'])\n", + "# optimizer from hugging face transformers\n", + "from transformers import AdamW\n", "\n", - "# we will use temp_text and temp_labels to create validation and test set\n", - "val_text, test_text, val_labels, test_labels = train_test_split(temp_text, temp_labels, \n", - " random_state=2018, \n", - " test_size=0.5, \n", - " stratify=temp_labels)" - ], - "execution_count": 5, - "outputs": [] + "# define the optimizer\n", + "optimizer = AdamW(model.parameters(), lr = 1e-3)" + ] }, { "cell_type": "markdown", "metadata": { - "id": "n7hsdLoCO7uB", - "colab_type": "text" + "colab_type": "text", + "id": "j9CDpoMQR_rK" }, "source": [ - "# Import BERT Model and BERT Tokenizer" + "# Find Class Weights" ] }, { "cell_type": "code", + "execution_count": 19, "metadata": { - "id": "S1kY3gZjO2RE", - "colab_type": "code", "colab": { "base_uri": "https://localhost:8080/", - "height": 164, - "referenced_widgets": [ - "983fea7c2dc74dfaba7aa60147af85d1", - "ccf5f7e5cc10493ca9c44b14fdec31dc", - "59bae99ad63d4a3a8b8d622d95f7ad07", - "689e66a8dff249449b5f0f5bbfffa037", - "49dd79a9a65044ba8345deb250ce4b24", - "47862cd626cf46619a5cc505fde02276", - "cb5f7a2a5bcb47649703cc633f2fb685", - "6d2355752eb74f348596a380d2347b73", - "0e580433bec2453da54c0ce9ee027401", - "bdfd7634b8bf42aa8794ada8d9e47173", - "73fc7587f1bb49df8a0fc87ecfac7f3c", - "4da1c15300b2468ab1a7e2df800ed39b", - "645d520e8a1c4f1fa202c6c68c5ce6af", - "3bb6b624b4ce4be788c38cb8d1936177", - "ed9f97f9d12a49aa939b7595ad3cb27c", - "88214abee8b9462f86369072d858ae9f", - "b4bef5a685954e238b52c43eefe4c9e5", - "94264f36ceb64d3881fed952bf579072", - "8cf06dad410440f78c50e3527e858905", - "edf0e4c1ae214a66abb717b20e3bffad", - "4d56a47453914de3be15d7a515e5b210", - "6831c2b733d74b31a41cb6eb971a25d7", - "58cd585c531444a5b8613c2d85bab022", - "b6423c858927455e8cbc5a953273466a" - ] + "height": 34 }, - "outputId": "4194574c-05d6-4d1d-c4c0-8dd89913ff79" + "colab_type": "code", + "id": "izY5xH5eR7Ur", + "outputId": "4682d190-bf40-4824-89af-91983ae6b174", + "vscode": { + "languageId": "python" + } }, - "source": [ - "# import BERT-base pretrained model\n", - "bert = AutoModel.from_pretrained('bert-base-uncased')\n", - "\n", - "# Load the BERT tokenizer\n", - "tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')" - ], - "execution_count": 6, "outputs": [ { - "output_type": "display_data", - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "983fea7c2dc74dfaba7aa60147af85d1", - "version_minor": 0, - "version_major": 2 - }, - "text/plain": [ - "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=433.0, style=ProgressStyle(description_…" - ] - }, - "metadata": { - "tags": [] - } - }, - { - "output_type": "stream", - "text": [ - "\n" - ], - "name": "stdout" - }, - { - "output_type": "display_data", - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "0e580433bec2453da54c0ce9ee027401", - "version_minor": 0, - "version_major": 2 - }, - "text/plain": [ - "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=440473133.0, style=ProgressStyle(descri…" - ] - }, - "metadata": { - "tags": [] - } - }, - { - "output_type": "stream", - "text": [ - "\n" - ], - "name": "stdout" - }, - { - "output_type": "display_data", - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "b4bef5a685954e238b52c43eefe4c9e5", - "version_minor": 0, - "version_major": 2 - }, - "text/plain": [ - "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=231508.0, style=ProgressStyle(descripti…" - ] - }, - "metadata": { - "tags": [] - } - }, - { + "name": "stdout", "output_type": "stream", "text": [ - "\n" - ], - "name": "stdout" + "[0.57743559 3.72848948]\n" + ] } + ], + "source": [ + "from sklearn.utils.class_weight import compute_class_weight\n", + "\n", + "#compute the class weights\n", + "class_wts = compute_class_weight('balanced', classes=np.unique(train_labels), y=train_labels)\n", + "\n", + "print(class_wts)" ] }, { "cell_type": "code", + "execution_count": 20, "metadata": { - "id": "_zOKeOMeO-DT", + "colab": {}, "colab_type": "code", - "colab": {} - }, - "source": [ - "# sample data\n", - "text = [\"this is a bert model tutorial\", \"we will fine-tune a bert model\"]\n", - "\n", - "# encode text\n", - "sent_id = tokenizer.batch_encode_plus(text, padding=True, return_token_type_ids=False)" - ], - "execution_count": 7, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "oAH73n39PHLw", - "colab_type": "code", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 54 - }, - "outputId": "17b76300-71f2-464c-90b0-a5907f4f675a" + "id": "r1WvfY2vSGKi", + "vscode": { + "languageId": "python" + } }, + "outputs": [], "source": [ - "# output\n", - "print(sent_id)" - ], - "execution_count": 8, - "outputs": [ - { - "output_type": "stream", - "text": [ - "{'input_ids': [[101, 2023, 2003, 1037, 14324, 2944, 14924, 4818, 102, 0], [101, 2057, 2097, 2986, 1011, 8694, 1037, 14324, 2944, 102]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 0], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]}\n" - ], - "name": "stdout" - } + "# convert class weights to tensor\n", + "weights= torch.tensor(class_wts,dtype=torch.float)\n", + "weights = weights.to(device)\n", + "\n", + "# loss function\n", + "cross_entropy = nn.NLLLoss(weight=weights) \n", + "\n", + "# number of training epochs\n", + "epochs = 10" ] }, { "cell_type": "markdown", "metadata": { - "id": "8wIYaWI_Prg8", - "colab_type": "text" + "colab_type": "text", + "id": "My4CA0qaShLq" }, "source": [ - "# Tokenization" + "# Fine-Tune BERT" ] }, { "cell_type": "code", + "execution_count": 21, "metadata": { - "id": "yKwbpeN_PMiu", + "colab": {}, "colab_type": "code", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 282 - }, - "outputId": "9f843240-6cf4-46c9-80b7-dc9ab4e03602" - }, - "source": [ - "# get length of all the messages in the train set\n", - "seq_len = [len(i.split()) for i in train_text]\n", - "\n", - "pd.Series(seq_len).hist(bins = 30)" - ], - "execution_count": 9, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "" - ] - }, - "metadata": { - "tags": [] - }, - "execution_count": 9 - }, - { - "output_type": "display_data", - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX4AAAD4CAYAAADrRI2NAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAUhklEQVR4nO3df5Dcd13H8efbxhboYdIfzE0niV7QiFMbleamrYMyd8aBNEVSFZl2OpBgnYxji8XWoUFG66jMBBURRsSJpkPQyhURprEtQgw9Gf5IpamlSVtKryVIbkIqtASPVjH69o/9nG7Pu9z+SHZv83k+Zm7uu5/vZ3df++32td/97vc2kZlIkurxXf0OIEnqLYtfkipj8UtSZSx+SaqMxS9JlVnW7wAnc+GFF+bIyEjb1/v2t7/Nueeee+oDnUaDlnnQ8oKZe2XQMg9aXlg884EDB76emS9bcEJmLtmf9evXZyfuu+++jq7XT4OWedDyZpq5VwYt86DlzVw8M/BAnqRbPdQjSZWx+CWpMha/JFXG4pekylj8klQZi1+SKmPxS1JlLH5JqozFL0mVWdJf2dArI9vvaWne4R1XneYkknT6uccvSZVZtPgj4vaIeDoiDjWN/UFEfDEiHo6IT0TEiqZ174iIqYh4PCJe2zS+sYxNRcT2U/9QJEmtaGWP/0PAxjlje4FLMvNHgC8B7wCIiIuBa4AfLtf504g4KyLOAj4AXAlcDFxb5kqSemzR4s/MzwLPzBn7dGaeKBf3A6vK8mZgIjP/IzO/DEwBl5Wfqcx8KjO/A0yUuZKkHovGN3guMiliBLg7My+ZZ93fAXdm5l9FxJ8A+zPzr8q6XcAny9SNmflLZfxNwOWZeeM8t7cN2AYwPDy8fmJiou0HNTMzw9DQUMvzD04fb2neupXL287SqnYz99ug5QUz98qgZR60vLB45vHx8QOZObrQ+q7O6omIdwIngDu6uZ1mmbkT2AkwOjqaY2Njbd/G5OQk7Vxva6tn9VzXfpZWtZu53wYtL5i5VwYt86Dlhe4zd1z8EbEVeB2wIf/vbcM0sLpp2qoyxknGJUk91NHpnBGxEXg78PrMfK5p1R7gmog4JyLWAGuBfwI+D6yNiDURcTaND4D3dBddktSJRff4I+IjwBhwYUQcAW6jcRbPOcDeiIDGcf1fzsxHIuKjwKM0DgHdkJn/VW7nRuBTwFnA7Zn5yGl4PJKkRSxa/Jl57TzDu04y/13Au+YZvxe4t610kqRTzr/claTKWPySVBmLX5IqY/FLUmUsfkmqjMUvSZWx+CWpMha/JFXG4pekylj8klQZi1+SKmPxS1JlLH5JqozFL0mVsfglqTIWvyRVxuKXpMpY/JJUGYtfkipj8UtSZSx+SaqMxS9JlbH4JakyFr8kVcbil6TKLFr8EXF7RDwdEYeaxs6PiL0R8UT5fV4Zj4h4f0RMRcTDEXFp03W2lPlPRMSW0/NwJEmLaWWP/0PAxjlj24F9mbkW2FcuA1wJrC0/24APQuOFArgNuBy4DLht9sVCktRbixZ/Zn4WeGbO8GZgd1neDVzdNP7hbNgPrIiIi4DXAnsz85nMfBbYy/9/MZEk9UBk5uKTIkaAuzPzknL5m5m5oiwH8GxmroiIu4Edmfm5sm4fcCswBrwoM3+vjP8m8Hxm/uE897WNxrsFhoeH109MTLT9oGZmZhgaGmp5/sHp4y3NW7dyedtZWtVu5n4btLxg5l4ZtMyDlhcWzzw+Pn4gM0cXWr+s2wCZmRGx+KtH67e3E9gJMDo6mmNjY23fxuTkJO1cb+v2e1qad/i69rO0qt3M/TZoecHMvTJomQctL3SfudOzeo6VQziU30+X8WlgddO8VWVsoXFJUo91Wvx7gNkzc7YAdzWNv7mc3XMFcDwzjwKfAl4TEeeVD3VfU8YkST226KGeiPgIjWP0F0bEERpn5+wAPhoR1wNfAd5Ypt8LbAKmgOeAtwBk5jMR8bvA58u838nMuR8YS5J6YNHiz8xrF1i1YZ65CdywwO3cDtzeVjpJ0innX+5KUmUsfkmqjMUvSZWx+CWpMha/JFXG4pekylj8klQZi1+SKmPxS1JlLH5JqozFL0mVsfglqTIWvyRVxuKXpMpY/JJUGYtfkipj8UtSZSx+SaqMxS9JlbH4JakyFr8kVcbil6TKWPySVBmLX5Iq01XxR8SvRcQjEXEoIj4SES+KiDURcX9ETEXEnRFxdpl7Trk8VdaPnIoHIElqT8fFHxErgV8FRjPzEuAs4Brg3cB7M/MHgGeB68tVrgeeLePvLfMkST3W7aGeZcCLI2IZ8BLgKPBTwMfK+t3A1WV5c7lMWb8hIqLL+5cktSkys/MrR9wEvAt4Hvg0cBOwv+zVExGrgU9m5iURcQjYmJlHyrongcsz8+tzbnMbsA1geHh4/cTERNu5ZmZmGBoaann+wenjLc1bt3J521la1W7mfhu0vGDmXhm0zIOWFxbPPD4+fiAzRxdav6zTO46I82jsxa8Bvgn8DbCx09ublZk7gZ0Ao6OjOTY21vZtTE5O0s71tm6/p6V5h69rP0ur2s3cb4OWF8zcK4OWedDyQveZuznU89PAlzPzXzPzP4GPA68CVpRDPwCrgOmyPA2sBijrlwPf6OL+JUkd6Kb4/wW4IiJeUo7VbwAeBe4D3lDmbAHuKst7ymXK+s9kN8eZJEkd6bj4M/N+Gh/SPggcLLe1E7gVuDkipoALgF3lKruAC8r4zcD2LnJLkjrU8TF+gMy8DbhtzvBTwGXzzP134Be6ub92jbR47F6SauJf7kpSZSx+SaqMxS9JlbH4JakyFr8kVcbil6TKWPySVBmLX5IqY/FLUmUsfkmqjMUvSZWx+CWpMha/JFXG4pekylj8klQZi1+SKmPxS1JlLH5JqozFL0mVsfglqTIWvyRVxuKXpMpY/JJUGYtfkipj8UtSZboq/ohYEREfi4gvRsRjEfHjEXF+ROyNiCfK7/PK3IiI90fEVEQ8HBGXnpqHIElqR7d7/O8D/j4zfwj4UeAxYDuwLzPXAvvKZYArgbXlZxvwwS7vW5LUgY6LPyKWA68GdgFk5ncy85vAZmB3mbYbuLosbwY+nA37gRURcVHHySVJHYnM7OyKET8G7AQepbG3fwC4CZjOzBVlTgDPZuaKiLgb2JGZnyvr9gG3ZuYDc253G413BAwPD6+fmJhoO9vMzAxDQ0McnD7e0WNbyLqVy0/p7TWbzTwoBi0vmLlXBi3zoOWFxTOPj48fyMzRhdYv6+K+lwGXAm/NzPsj4n3832EdADIzI6KtV5bM3EnjBYXR0dEcGxtrO9jk5CRjY2Ns3X5P29c9mcPXtZ+lVbOZB8Wg5QUz98qgZR60vNB95m6O8R8BjmTm/eXyx2i8EBybPYRTfj9d1k8Dq5uuv6qMSZJ6qOPiz8yvAV+NiFeUoQ00DvvsAbaUsS3AXWV5D/DmcnbPFcDxzDza6f1LkjrTzaEegLcCd0TE2cBTwFtovJh8NCKuB74CvLHMvRfYBEwBz5W5kqQe66r4M/MhYL4PEDbMMzeBG7q5P0lS9/zLXUmqjMUvSZWx+CWpMt1+uKsujDT9ncEt604s+HcHh3dc1atIkirgHr8kVcbil6TKWPySVBmLX5Iq44e7bRhp8Uvf/DBW0lLmHr8kVcbil6TKWPySVBmLX5IqY/FLUmU8q+c0aPXsH0nqB/f4JakyFr8kVcbil6TKWPySVBmLX5IqY/FLUmUsfkmqjMUvSZWx+CWpMl0Xf0ScFRH/HBF3l8trIuL+iJiKiDsj4uwyfk65PFXWj3R735Kk9p2KPf6bgMeaLr8beG9m/gDwLHB9Gb8eeLaMv7fMkyT1WFfFHxGrgKuAvyiXA/gp4GNlym7g6rK8uVymrN9Q5kuSeqjbPf4/Bt4O/He5fAHwzcw8US4fAVaW5ZXAVwHK+uNlviSphyIzO7tixOuATZn5KxExBvw6sBXYXw7nEBGrgU9m5iURcQjYmJlHyrongcsz8+tzbncbsA1geHh4/cTERNvZZmZmGBoa4uD08Y4eWz8MvxiOPT//unUrl/c2TAtmt/EgMXNvDFrmQcsLi2ceHx8/kJmjC63v5muZXwW8PiI2AS8Cvgd4H7AiIpaVvfpVwHSZPw2sBo5ExDJgOfCNuTeamTuBnQCjo6M5NjbWdrDJyUnGxsbYOkBfj3zLuhO85+D8/zkOXzfW2zAtmN3Gg8TMvTFomQctL3SfueNDPZn5jsxclZkjwDXAZzLzOuA+4A1l2hbgrrK8p1ymrP9Mdvp2Q5LUsdNxHv+twM0RMUXjGP6uMr4LuKCM3wxsPw33LUlaxCn5F7gycxKYLMtPAZfNM+ffgV84FfcnSeqcf7krSZWx+CWpMha/JFXG4pekylj8klQZi1+SKmPxS1JlLH5JqozFL0mVsfglqTIWvyRVxuKXpMpY/JJUGYtfkipj8UtSZSx+SaqMxS9JlbH4JakyFr8kVcbil6TKWPySVBmLX5IqY/FLUmUsfkmqjMUvSZXpuPgjYnVE3BcRj0bEIxFxUxk/PyL2RsQT5fd5ZTwi4v0RMRURD0fEpafqQUiSWresi+ueAG7JzAcj4qXAgYjYC2wF9mXmjojYDmwHbgWuBNaWn8uBD5bfWsTI9ntannt4x1WnMYmkM0HHe/yZeTQzHyzL/wY8BqwENgO7y7TdwNVleTPw4WzYD6yIiIs6Ti5J6khkZvc3EjECfBa4BPiXzFxRxgN4NjNXRMTdwI7M/FxZtw+4NTMfmHNb24BtAMPDw+snJibazjMzM8PQ0BAHp493/qB6bPjFcOz57m9n3crl3d9IC2a38SAxc28MWuZBywuLZx4fHz+QmaMLre/mUA8AETEE/C3wtsz8VqPrGzIzI6KtV5bM3AnsBBgdHc2xsbG2M01OTjI2NsbWNg6R9Nst607wnoNd/+fg8HVj3Ydpwew2HiRm7o1ByzxoeaH7zF2d1RMR302j9O/IzI+X4WOzh3DK76fL+DSwuunqq8qYJKmHujmrJ4BdwGOZ+UdNq/YAW8ryFuCupvE3l7N7rgCOZ+bRTu9fktSZbo4tvAp4E3AwIh4qY78B7AA+GhHXA18B3ljW3QtsAqaA54C3dHHfkqQOdVz85UPaWGD1hnnmJ3BDp/cnSTo1/MtdSaqMxS9JlbH4JakyFr8kVcbil6TKWPySVBmLX5IqY/FLUmUsfkmqjMUvSZWx+CWpMha/JFXG4pekylj8klQZi1+SKmPxS1JlLH5Jqkw3//SilqCR7fe0NO/wjqtOcxJJS5V7/JJUGYtfkipj8UtSZSx+SaqMxS9JlbH4Jakyns5ZKU/7lOrV8+KPiI3A+4CzgL/IzB29zqD+8QVH6r+eHuqJiLOADwBXAhcD10bExb3MIEm16/Ue/2XAVGY+BRARE8Bm4NEe51CLFtpDv2XdCba2uPd+OrX6DgJaz9zqu4127rsVvstRr0Rm9u7OIt4AbMzMXyqX3wRcnpk3Ns3ZBmwrF18BPN7BXV0IfL3LuL02aJkHLS+YuVcGLfOg5YXFM39fZr5soZVL7sPdzNwJ7OzmNiLigcwcPUWRemLQMg9aXjBzrwxa5kHLC91n7vXpnNPA6qbLq8qYJKlHel38nwfWRsSaiDgbuAbY0+MMklS1nh7qycwTEXEj8Ckap3PenpmPnIa76upQUZ8MWuZBywtm7pVByzxoeaHbw+G9/HBXktR/fmWDJFXG4pekypxRxR8RGyPi8YiYiojt/c4zn4hYHRH3RcSjEfFIRNxUxn87IqYj4qHys6nfWZtFxOGIOFiyPVDGzo+IvRHxRPl9Xr9zzoqIVzRty4ci4lsR8baltp0j4vaIeDoiDjWNzbtdo+H95fn9cERcukTy/kFEfLFk+kRErCjjIxHxfNO2/rNe5z1J5gWfBxHxjrKNH4+I1y6hzHc25T0cEQ+V8fa3c2aeET80Pix+Eng5cDbwBeDifueaJ+dFwKVl+aXAl2h8fcVvA7/e73wnyX0YuHDO2O8D28vyduDd/c55kufG14DvW2rbGXg1cClwaLHtCmwCPgkEcAVw/xLJ+xpgWVl+d1PekeZ5S2wbz/s8KP8vfgE4B1hTOuWspZB5zvr3AL/V6XY+k/b4//frIDLzO8Ds10EsKZl5NDMfLMv/BjwGrOxvqo5tBnaX5d3A1X3McjIbgCcz8yv9DjJXZn4WeGbO8ELbdTPw4WzYD6yIiIt6k7RhvryZ+enMPFEu7qfx9zlLxgLbeCGbgYnM/I/M/DIwRaNbeupkmSMigDcCH+n09s+k4l8JfLXp8hGWeKFGxAjwSuD+MnRjebt8+1I6bFIk8OmIOFC+VgNgODOPluWvAcP9ibaoa3jh/yRLeTvDwtt1EJ7jv0jjXcmsNRHxzxHxjxHxk/0KtYD5ngeDsI1/EjiWmU80jbW1nc+k4h8oETEE/C3wtsz8FvBB4PuBHwOO0ngrt5T8RGZeSuObVW+IiFc3r8zGe84ld25w+UPB1wN/U4aW+nZ+gaW6XecTEe8ETgB3lKGjwPdm5iuBm4G/jojv6Ve+OQbqeTDHtbxwR6bt7XwmFf/AfB1ERHw3jdK/IzM/DpCZxzLzvzLzv4E/pw9vL08mM6fL76eBT9DId2z2UEP5/XT/Ei7oSuDBzDwGS387Fwtt1yX7HI+IrcDrgOvKixXlcMk3yvIBGsfLf7BvIZuc5HmwZLcxQEQsA34OuHN2rJPtfCYV/0B8HUQ5PrcLeCwz/6hpvPlY7c8Ch+Zet18i4tyIeOnsMo0P8w7R2L5byrQtwF39SXhSL9g7WsrbuclC23UP8OZyds8VwPGmQ0J9E41/XOntwOsz87mm8ZdF49/gICJeDqwFnupPyhc6yfNgD3BNRJwTEWtoZP6nXuc7iZ8GvpiZR2YHOtrOvf60+jR/Er6JxlkyTwLv7HeeBTL+BI237g8DD5WfTcBfAgfL+B7gon5nbcr8chpnOnwBeGR22wIXAPuAJ4B/AM7vd9Y5uc8FvgEsbxpbUtuZxovSUeA/aRxPvn6h7UrjbJ4PlOf3QWB0ieSdonFcfPb5/Gdl7s+X58tDwIPAzyyhbbzg8wB4Z9nGjwNXLpXMZfxDwC/Pmdv2dvYrGySpMmfSoR5JUgssfkmqjMUvSZWx+CWpMha/JFXG4pekylj8klSZ/wEmDJk6BAzVDQAAAABJRU5ErkJggg==\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "tags": [], - "needs_background": "light" - } + "id": "rskLk8R_SahS", + "vscode": { + "languageId": "python" } - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "OXcswEIRPvGe", - "colab_type": "code", - "colab": {} - }, - "source": [ - "max_seq_len = 25" - ], - "execution_count": 10, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "tk5S7DWaP2t6", - "colab_type": "code", - "colab": {} }, + "outputs": [], "source": [ - "# tokenize and encode sequences in the training set\n", - "tokens_train = tokenizer.batch_encode_plus(\n", - " train_text.tolist(),\n", - " max_length = max_seq_len,\n", - " pad_to_max_length=True,\n", - " truncation=True,\n", - " return_token_type_ids=False\n", - ")\n", + "# function to train the model\n", + "def train():\n", + " \n", + " model.train()\n", "\n", - "# tokenize and encode sequences in the validation set\n", - "tokens_val = tokenizer.batch_encode_plus(\n", - " val_text.tolist(),\n", - " max_length = max_seq_len,\n", - " pad_to_max_length=True,\n", - " truncation=True,\n", - " return_token_type_ids=False\n", - ")\n", + " total_loss, total_accuracy = 0, 0\n", + " \n", + " # empty list to save model predictions\n", + " total_preds=[]\n", + " \n", + " # iterate over batches\n", + " for step,batch in enumerate(train_dataloader):\n", + " \n", + " # progress update after every 50 batches.\n", + " if step % 50 == 0 and not step == 0:\n", + " print(' Batch {:>5,} of {:>5,}.'.format(step, len(train_dataloader)))\n", "\n", - "# tokenize and encode sequences in the test set\n", - "tokens_test = tokenizer.batch_encode_plus(\n", - " test_text.tolist(),\n", - " max_length = max_seq_len,\n", - " pad_to_max_length=True,\n", - " truncation=True,\n", - " return_token_type_ids=False\n", - ")" - ], - "execution_count": 11, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Wsm8bkRZQTw9", - "colab_type": "text" - }, - "source": [ - "# Convert Integer Sequences to Tensors" + " # push the batch to gpu\n", + " batch = [r.to(device) for r in batch]\n", + " \n", + " sent_id, mask, labels = batch\n", + "\n", + " # clear previously calculated gradients \n", + " model.zero_grad() \n", + "\n", + " # get model predictions for the current batch\n", + " preds = model(sent_id, mask)\n", + "\n", + " # compute the loss between actual and predicted values\n", + " loss = cross_entropy(preds, labels)\n", + "\n", + " # add on to the total loss\n", + " total_loss = total_loss + loss.item()\n", + "\n", + " # backward pass to calculate the gradients\n", + " loss.backward()\n", + "\n", + " # clip the the gradients to 1.0. It helps in preventing the exploding gradient problem\n", + " torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n", + "\n", + " # update parameters\n", + " optimizer.step()\n", + "\n", + " # model predictions are stored on GPU. So, push it to CPU\n", + " preds=preds.detach().cpu().numpy()\n", + "\n", + " # append the model predictions\n", + " total_preds.append(preds)\n", + "\n", + " # compute the training loss of the epoch\n", + " avg_loss = total_loss / len(train_dataloader)\n", + " \n", + " # predictions are in the form of (no. of batches, size of batch, no. of classes).\n", + " # reshape the predictions in form of (number of samples, no. of classes)\n", + " total_preds = np.concatenate(total_preds, axis=0)\n", + "\n", + " #returns the loss and predictions\n", + " return avg_loss, total_preds" ] }, { "cell_type": "code", + "execution_count": 22, "metadata": { - "id": "QR-lXwmzQPd6", + "colab": {}, "colab_type": "code", - "colab": {} + "id": "yGXovFDlSxB5", + "vscode": { + "languageId": "python" + } }, + "outputs": [], "source": [ - "# for train set\n", - "train_seq = torch.tensor(tokens_train['input_ids'])\n", - "train_mask = torch.tensor(tokens_train['attention_mask'])\n", - "train_y = torch.tensor(train_labels.tolist())\n", + "# function for evaluating the model\n", + "def evaluate():\n", + " \n", + " print(\"\\nEvaluating...\")\n", + " \n", + " # deactivate dropout layers\n", + " model.eval()\n", "\n", - "# for validation set\n", - "val_seq = torch.tensor(tokens_val['input_ids'])\n", - "val_mask = torch.tensor(tokens_val['attention_mask'])\n", - "val_y = torch.tensor(val_labels.tolist())\n", + " total_loss, total_accuracy = 0, 0\n", + " \n", + " # empty list to save the model predictions\n", + " total_preds = []\n", "\n", - "# for test set\n", - "test_seq = torch.tensor(tokens_test['input_ids'])\n", - "test_mask = torch.tensor(tokens_test['attention_mask'])\n", - "test_y = torch.tensor(test_labels.tolist())" - ], - "execution_count": 12, - "outputs": [] + " # iterate over batches\n", + " for step,batch in enumerate(val_dataloader):\n", + " \n", + " # Progress update every 50 batches.\n", + " if step % 50 == 0 and not step == 0:\n", + " \n", + " # Calculate elapsed time in minutes.\n", + " elapsed = format_time(time.time() - t0)\n", + " \n", + " # Report progress.\n", + " print(' Batch {:>5,} of {:>5,}.'.format(step, len(val_dataloader)))\n", + "\n", + " # push the batch to gpu\n", + " batch = [t.to(device) for t in batch]\n", + "\n", + " sent_id, mask, labels = batch\n", + "\n", + " # deactivate autograd\n", + " with torch.no_grad():\n", + " \n", + " # model predictions\n", + " preds = model(sent_id, mask)\n", + "\n", + " # compute the validation loss between actual and predicted values\n", + " loss = cross_entropy(preds,labels)\n", + "\n", + " total_loss = total_loss + loss.item()\n", + "\n", + " preds = preds.detach().cpu().numpy()\n", + "\n", + " total_preds.append(preds)\n", + "\n", + " # compute the validation loss of the epoch\n", + " avg_loss = total_loss / len(val_dataloader) \n", + "\n", + " # reshape the predictions in form of (number of samples, no. of classes)\n", + " total_preds = np.concatenate(total_preds, axis=0)\n", + "\n", + " return avg_loss, total_preds" + ] }, { "cell_type": "markdown", "metadata": { - "id": "Ov1cOBlcRLuk", - "colab_type": "text" + "colab_type": "text", + "id": "9KZEgxRRTLXG" }, "source": [ - "# Create DataLoaders" + "# Start Model Training" ] }, { "cell_type": "code", + "execution_count": 23, "metadata": { - "id": "qUy9JKFYQYLp", - "colab_type": "code", - "colab": {} + "colab": { + "base_uri": "https://localhost:8080/", + "height": 1000 + }, + "colab_type": "code", + "id": "k1USGTntS3TS", + "outputId": "6c03e17f-476c-4741-eae5-c5722cb5d413", + "vscode": { + "languageId": "python" + } }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + " Epoch 1 / 10\n", + " Batch 50 of 122.\n", + " Batch 100 of 122.\n", + "\n", + "Evaluating...\n", + "\n", + "Training Loss: 0.526\n", + "Validation Loss: 0.656\n", + "\n", + " Epoch 2 / 10\n", + " Batch 50 of 122.\n", + " Batch 100 of 122.\n", + "\n", + "Evaluating...\n", + "\n", + "Training Loss: 0.345\n", + "Validation Loss: 0.231\n", + "\n", + " Epoch 3 / 10\n", + " Batch 50 of 122.\n", + " Batch 100 of 122.\n", + "\n", + "Evaluating...\n", + "\n", + "Training Loss: 0.344\n", + "Validation Loss: 0.194\n", + "\n", + " Epoch 4 / 10\n", + " Batch 50 of 122.\n", + " Batch 100 of 122.\n", + "\n", + "Evaluating...\n", + "\n", + "Training Loss: 0.223\n", + "Validation Loss: 0.171\n", + "\n", + " Epoch 5 / 10\n", + " Batch 50 of 122.\n", + " Batch 100 of 122.\n", + "\n", + "Evaluating...\n", + "\n", + "Training Loss: 0.219\n", + "Validation Loss: 0.178\n", + "\n", + " Epoch 6 / 10\n", + " Batch 50 of 122.\n", + " Batch 100 of 122.\n", + "\n", + "Evaluating...\n", + "\n", + "Training Loss: 0.215\n", + "Validation Loss: 0.180\n", + "\n", + " Epoch 7 / 10\n", + " Batch 50 of 122.\n", + " Batch 100 of 122.\n", + "\n", + "Evaluating...\n", + "\n", + "Training Loss: 0.247\n", + "Validation Loss: 0.262\n", + "\n", + " Epoch 8 / 10\n", + " Batch 50 of 122.\n", + " Batch 100 of 122.\n", + "\n", + "Evaluating...\n", + "\n", + "Training Loss: 0.224\n", + "Validation Loss: 0.217\n", + "\n", + " Epoch 9 / 10\n", + " Batch 50 of 122.\n", + " Batch 100 of 122.\n", + "\n", + "Evaluating...\n", + "\n", + "Training Loss: 0.217\n", + "Validation Loss: 0.148\n", + "\n", + " Epoch 10 / 10\n", + " Batch 50 of 122.\n", + " Batch 100 of 122.\n", + "\n", + "Evaluating...\n", + "\n", + "Training Loss: 0.231\n", + "Validation Loss: 0.639\n" + ] + } + ], "source": [ - "from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler\n", - "\n", - "#define a batch size\n", - "batch_size = 32\n", - "\n", - "# wrap tensors\n", - "train_data = TensorDataset(train_seq, train_mask, train_y)\n", - "\n", - "# sampler for sampling the data during training\n", - "train_sampler = RandomSampler(train_data)\n", - "\n", - "# dataLoader for train set\n", - "train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=batch_size)\n", - "\n", - "# wrap tensors\n", - "val_data = TensorDataset(val_seq, val_mask, val_y)\n", + "# set initial loss to infinite\n", + "best_valid_loss = float('inf')\n", "\n", - "# sampler for sampling the data during training\n", - "val_sampler = SequentialSampler(val_data)\n", + "# empty lists to store training and validation loss of each epoch\n", + "train_losses=[]\n", + "valid_losses=[]\n", "\n", - "# dataLoader for validation set\n", - "val_dataloader = DataLoader(val_data, sampler = val_sampler, batch_size=batch_size)" - ], - "execution_count": 13, - "outputs": [] + "#for each epoch\n", + "for epoch in range(epochs):\n", + " \n", + " print('\\n Epoch {:} / {:}'.format(epoch + 1, epochs))\n", + " \n", + " #train model\n", + " train_loss, _ = train()\n", + " \n", + " #evaluate model\n", + " valid_loss, _ = evaluate()\n", + " \n", + " #save the best model\n", + " if valid_loss < best_valid_loss:\n", + " best_valid_loss = valid_loss\n", + " torch.save(model.state_dict(), 'saved_weights.pt')\n", + " \n", + " # append training and validation loss\n", + " train_losses.append(train_loss)\n", + " valid_losses.append(valid_loss)\n", + " \n", + " print(f'\\nTraining Loss: {train_loss:.3f}')\n", + " print(f'Validation Loss: {valid_loss:.3f}')" + ] }, { "cell_type": "markdown", "metadata": { - "id": "K2HZc5ZYRV28", - "colab_type": "text" + "colab_type": "text", + "id": "_yrhUc9kTI5a" }, "source": [ - "# Freeze BERT Parameters" + "# Load Saved Model" ] }, { "cell_type": "code", + "execution_count": 24, "metadata": { - "id": "wHZ0MC00RQA_", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 34 + }, "colab_type": "code", - "colab": {} + "id": "OacxUyizS8d1", + "outputId": "c8b951c2-1f74-4a13-db65-8acd077995e5", + "vscode": { + "languageId": "python" + } }, - "source": [ - "# freeze all the parameters\n", - "for param in bert.parameters():\n", - " param.requires_grad = False" + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 24, + "metadata": { + "tags": [] + }, + "output_type": "execute_result" + } ], - "execution_count": 14, - "outputs": [] + "source": [ + "#load weights of best model\n", + "path = 'saved_weights.pt'\n", + "model.load_state_dict(torch.load(path))" + ] }, { "cell_type": "markdown", "metadata": { - "id": "s7ahGBUWRi3X", - "colab_type": "text" + "colab_type": "text", + "id": "x4SVftkkTZXA" }, "source": [ - "# Define Model Architecture" + "# Get Predictions for Test Data" ] }, { "cell_type": "code", + "execution_count": 25, "metadata": { - "id": "b3iEtGyYRd0A", + "colab": {}, "colab_type": "code", - "colab": {} + "id": "NZl0SZmFTRQA", + "vscode": { + "languageId": "python" + } }, + "outputs": [], "source": [ - "class BERT_Arch(nn.Module):\n", - "\n", - " def __init__(self, bert):\n", - " \n", - " super(BERT_Arch, self).__init__()\n", - "\n", - " self.bert = bert \n", - " \n", - " # dropout layer\n", - " self.dropout = nn.Dropout(0.1)\n", - " \n", - " # relu activation function\n", - " self.relu = nn.ReLU()\n", - "\n", - " # dense layer 1\n", - " self.fc1 = nn.Linear(768,512)\n", - " \n", - " # dense layer 2 (Output layer)\n", - " self.fc2 = nn.Linear(512,2)\n", - "\n", - " #softmax activation function\n", - " self.softmax = nn.LogSoftmax(dim=1)\n", - "\n", - " #define the forward pass\n", - " def forward(self, sent_id, mask):\n", - "\n", - " #pass the inputs to the model \n", - " _, cls_hs = self.bert(sent_id, attention_mask=mask)\n", - " \n", - " x = self.fc1(cls_hs)\n", - "\n", - " x = self.relu(x)\n", - "\n", - " x = self.dropout(x)\n", - "\n", - " # output layer\n", - " x = self.fc2(x)\n", - " \n", - " # apply softmax activation\n", - " x = self.softmax(x)\n", - "\n", - " return x" - ], - "execution_count": 15, - "outputs": [] + "# get predictions for test data\n", + "with torch.no_grad():\n", + " preds = model(test_seq.to(device), test_mask.to(device))\n", + " preds = preds.detach().cpu().numpy()" + ] }, { "cell_type": "code", + "execution_count": 26, "metadata": { - "id": "cBAJJVuJRliv", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 170 + }, "colab_type": "code", - "colab": {} + "id": "Ms1ObHZxTYSI", + "outputId": "47d01595-e519-4a58-8f2e-75596ea1512d", + "vscode": { + "languageId": "python" + } }, - "source": [ - "# pass the pre-trained BERT to our define architecture\n", - "model = BERT_Arch(bert)\n", - "\n", - "# push the model to GPU\n", - "model = model.to(device)" - ], - "execution_count": 16, - "outputs": [] - }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " precision recall f1-score support\n", + "\n", + " 0 0.99 0.98 0.98 724\n", + " 1 0.88 0.92 0.90 112\n", + "\n", + " accuracy 0.97 836\n", + " macro avg 0.93 0.95 0.94 836\n", + "weighted avg 0.97 0.97 0.97 836\n", + "\n" + ] + } + ], + "source": [ + "# model's performance\n", + "preds = np.argmax(preds, axis = 1)\n", + "print(classification_report(test_y, preds))" + ] + }, { "cell_type": "code", + "execution_count": 27, "metadata": { - "id": "taXS0IilRn9J", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 142 + }, "colab_type": "code", - "colab": {} + "id": "YqzLS7rHTp4T", + "outputId": "d3abc432-5ad0-41e5-cfc2-d1d1192f5672", + "vscode": { + "languageId": "python" + } }, - "source": [ - "# optimizer from hugging face transformers\n", - "from transformers import AdamW\n", - "\n", - "# define the optimizer\n", - "optimizer = AdamW(model.parameters(), lr = 1e-3)" + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
col_001
row_0
071014
19103
\n", + "
" + ], + "text/plain": [ + "col_0 0 1\n", + "row_0 \n", + "0 710 14\n", + "1 9 103" + ] + }, + "execution_count": 27, + "metadata": { + "tags": [] + }, + "output_type": "execute_result" + } ], - "execution_count": 17, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "j9CDpoMQR_rK", - "colab_type": "text" - }, "source": [ - "# Find Class Weights" + "# confusion matrix\n", + "pd.crosstab(test_y, preds)" ] }, { "cell_type": "code", + "execution_count": null, "metadata": { - "id": "izY5xH5eR7Ur", + "colab": {}, "colab_type": "code", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 34 + "id": "jpX1uTwjUPY6", + "vscode": { + "languageId": "python" + } + }, + "outputs": [], + "source": [] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "authorship_tag": "ABX9TyOt9x7x5Cm/ENCEI4+c+LvL", + "include_colab_link": true, + "name": "Fine-Tuning BERT for Spam Classification.ipynb", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "0e580433bec2453da54c0ce9ee027401": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_73fc7587f1bb49df8a0fc87ecfac7f3c", + "IPY_MODEL_4da1c15300b2468ab1a7e2df800ed39b" + ], + "layout": "IPY_MODEL_bdfd7634b8bf42aa8794ada8d9e47173" + } + }, + "3bb6b624b4ce4be788c38cb8d1936177": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "47862cd626cf46619a5cc505fde02276": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "49dd79a9a65044ba8345deb250ce4b24": { + "model_module": "@jupyter-widgets/controls", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "initial" + } + }, + "4d56a47453914de3be15d7a515e5b210": { + "model_module": "@jupyter-widgets/controls", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "initial" + } + }, + "4da1c15300b2468ab1a7e2df800ed39b": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_88214abee8b9462f86369072d858ae9f", + "placeholder": "​", + "style": "IPY_MODEL_ed9f97f9d12a49aa939b7595ad3cb27c", + "value": " 440M/440M [00:11<00:00, 37.5MB/s]" + } + }, + "58cd585c531444a5b8613c2d85bab022": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "59bae99ad63d4a3a8b8d622d95f7ad07": { + "model_module": "@jupyter-widgets/controls", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "Downloading: 100%", + "description_tooltip": null, + "layout": "IPY_MODEL_47862cd626cf46619a5cc505fde02276", + "max": 433, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_49dd79a9a65044ba8345deb250ce4b24", + "value": 433 + } + }, + "645d520e8a1c4f1fa202c6c68c5ce6af": { + "model_module": "@jupyter-widgets/controls", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "initial" + } + }, + "6831c2b733d74b31a41cb6eb971a25d7": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "689e66a8dff249449b5f0f5bbfffa037": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_6d2355752eb74f348596a380d2347b73", + "placeholder": "​", + "style": "IPY_MODEL_cb5f7a2a5bcb47649703cc633f2fb685", + "value": " 433/433 [00:00<00:00, 1.98kB/s]" + } + }, + "6d2355752eb74f348596a380d2347b73": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "73fc7587f1bb49df8a0fc87ecfac7f3c": { + "model_module": "@jupyter-widgets/controls", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "Downloading: 100%", + "description_tooltip": null, + "layout": "IPY_MODEL_3bb6b624b4ce4be788c38cb8d1936177", + "max": 440473133, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_645d520e8a1c4f1fa202c6c68c5ce6af", + "value": 440473133 + } + }, + "88214abee8b9462f86369072d858ae9f": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "8cf06dad410440f78c50e3527e858905": { + "model_module": "@jupyter-widgets/controls", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "Downloading: 100%", + "description_tooltip": null, + "layout": "IPY_MODEL_6831c2b733d74b31a41cb6eb971a25d7", + "max": 231508, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_4d56a47453914de3be15d7a515e5b210", + "value": 231508 + } }, - "outputId": "4682d190-bf40-4824-89af-91983ae6b174" - }, - "source": [ - "from sklearn.utils.class_weight import compute_class_weight\n", - "\n", - "#compute the class weights\n", - "class_wts = compute_class_weight('balanced', np.unique(train_labels), train_labels)\n", - "\n", - "print(class_wts)" - ], - "execution_count": 19, - "outputs": [ - { - "output_type": "stream", - "text": [ - "[0.57743559 3.72848948]\n" - ], - "name": "stdout" - } - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "r1WvfY2vSGKi", - "colab_type": "code", - "colab": {} - }, - "source": [ - "# convert class weights to tensor\n", - "weights= torch.tensor(class_wts,dtype=torch.float)\n", - "weights = weights.to(device)\n", - "\n", - "# loss function\n", - "cross_entropy = nn.NLLLoss(weight=weights) \n", - "\n", - "# number of training epochs\n", - "epochs = 10" - ], - "execution_count": 20, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "My4CA0qaShLq", - "colab_type": "text" - }, - "source": [ - "# Fine-Tune BERT" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "rskLk8R_SahS", - "colab_type": "code", - "colab": {} - }, - "source": [ - "# function to train the model\n", - "def train():\n", - " \n", - " model.train()\n", - "\n", - " total_loss, total_accuracy = 0, 0\n", - " \n", - " # empty list to save model predictions\n", - " total_preds=[]\n", - " \n", - " # iterate over batches\n", - " for step,batch in enumerate(train_dataloader):\n", - " \n", - " # progress update after every 50 batches.\n", - " if step % 50 == 0 and not step == 0:\n", - " print(' Batch {:>5,} of {:>5,}.'.format(step, len(train_dataloader)))\n", - "\n", - " # push the batch to gpu\n", - " batch = [r.to(device) for r in batch]\n", - " \n", - " sent_id, mask, labels = batch\n", - "\n", - " # clear previously calculated gradients \n", - " model.zero_grad() \n", - "\n", - " # get model predictions for the current batch\n", - " preds = model(sent_id, mask)\n", - "\n", - " # compute the loss between actual and predicted values\n", - " loss = cross_entropy(preds, labels)\n", - "\n", - " # add on to the total loss\n", - " total_loss = total_loss + loss.item()\n", - "\n", - " # backward pass to calculate the gradients\n", - " loss.backward()\n", - "\n", - " # clip the the gradients to 1.0. It helps in preventing the exploding gradient problem\n", - " torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n", - "\n", - " # update parameters\n", - " optimizer.step()\n", - "\n", - " # model predictions are stored on GPU. So, push it to CPU\n", - " preds=preds.detach().cpu().numpy()\n", - "\n", - " # append the model predictions\n", - " total_preds.append(preds)\n", - "\n", - " # compute the training loss of the epoch\n", - " avg_loss = total_loss / len(train_dataloader)\n", - " \n", - " # predictions are in the form of (no. of batches, size of batch, no. of classes).\n", - " # reshape the predictions in form of (number of samples, no. of classes)\n", - " total_preds = np.concatenate(total_preds, axis=0)\n", - "\n", - " #returns the loss and predictions\n", - " return avg_loss, total_preds" - ], - "execution_count": 21, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "yGXovFDlSxB5", - "colab_type": "code", - "colab": {} - }, - "source": [ - "# function for evaluating the model\n", - "def evaluate():\n", - " \n", - " print(\"\\nEvaluating...\")\n", - " \n", - " # deactivate dropout layers\n", - " model.eval()\n", - "\n", - " total_loss, total_accuracy = 0, 0\n", - " \n", - " # empty list to save the model predictions\n", - " total_preds = []\n", - "\n", - " # iterate over batches\n", - " for step,batch in enumerate(val_dataloader):\n", - " \n", - " # Progress update every 50 batches.\n", - " if step % 50 == 0 and not step == 0:\n", - " \n", - " # Calculate elapsed time in minutes.\n", - " elapsed = format_time(time.time() - t0)\n", - " \n", - " # Report progress.\n", - " print(' Batch {:>5,} of {:>5,}.'.format(step, len(val_dataloader)))\n", - "\n", - " # push the batch to gpu\n", - " batch = [t.to(device) for t in batch]\n", - "\n", - " sent_id, mask, labels = batch\n", - "\n", - " # deactivate autograd\n", - " with torch.no_grad():\n", - " \n", - " # model predictions\n", - " preds = model(sent_id, mask)\n", - "\n", - " # compute the validation loss between actual and predicted values\n", - " loss = cross_entropy(preds,labels)\n", - "\n", - " total_loss = total_loss + loss.item()\n", - "\n", - " preds = preds.detach().cpu().numpy()\n", - "\n", - " total_preds.append(preds)\n", - "\n", - " # compute the validation loss of the epoch\n", - " avg_loss = total_loss / len(val_dataloader) \n", - "\n", - " # reshape the predictions in form of (number of samples, no. of classes)\n", - " total_preds = np.concatenate(total_preds, axis=0)\n", - "\n", - " return avg_loss, total_preds" - ], - "execution_count": 22, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "9KZEgxRRTLXG", - "colab_type": "text" - }, - "source": [ - "# Start Model Training" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "k1USGTntS3TS", - "colab_type": "code", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 1000 + "94264f36ceb64d3881fed952bf579072": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "983fea7c2dc74dfaba7aa60147af85d1": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_59bae99ad63d4a3a8b8d622d95f7ad07", + "IPY_MODEL_689e66a8dff249449b5f0f5bbfffa037" + ], + "layout": "IPY_MODEL_ccf5f7e5cc10493ca9c44b14fdec31dc" + } }, - "outputId": "6c03e17f-476c-4741-eae5-c5722cb5d413" - }, - "source": [ - "# set initial loss to infinite\n", - "best_valid_loss = float('inf')\n", - "\n", - "# empty lists to store training and validation loss of each epoch\n", - "train_losses=[]\n", - "valid_losses=[]\n", - "\n", - "#for each epoch\n", - "for epoch in range(epochs):\n", - " \n", - " print('\\n Epoch {:} / {:}'.format(epoch + 1, epochs))\n", - " \n", - " #train model\n", - " train_loss, _ = train()\n", - " \n", - " #evaluate model\n", - " valid_loss, _ = evaluate()\n", - " \n", - " #save the best model\n", - " if valid_loss < best_valid_loss:\n", - " best_valid_loss = valid_loss\n", - " torch.save(model.state_dict(), 'saved_weights.pt')\n", - " \n", - " # append training and validation loss\n", - " train_losses.append(train_loss)\n", - " valid_losses.append(valid_loss)\n", - " \n", - " print(f'\\nTraining Loss: {train_loss:.3f}')\n", - " print(f'Validation Loss: {valid_loss:.3f}')" - ], - "execution_count": 23, - "outputs": [ - { - "output_type": "stream", - "text": [ - "\n", - " Epoch 1 / 10\n", - " Batch 50 of 122.\n", - " Batch 100 of 122.\n", - "\n", - "Evaluating...\n", - "\n", - "Training Loss: 0.526\n", - "Validation Loss: 0.656\n", - "\n", - " Epoch 2 / 10\n", - " Batch 50 of 122.\n", - " Batch 100 of 122.\n", - "\n", - "Evaluating...\n", - "\n", - "Training Loss: 0.345\n", - "Validation Loss: 0.231\n", - "\n", - " Epoch 3 / 10\n", - " Batch 50 of 122.\n", - " Batch 100 of 122.\n", - "\n", - "Evaluating...\n", - "\n", - "Training Loss: 0.344\n", - "Validation Loss: 0.194\n", - "\n", - " Epoch 4 / 10\n", - " Batch 50 of 122.\n", - " Batch 100 of 122.\n", - "\n", - "Evaluating...\n", - "\n", - "Training Loss: 0.223\n", - "Validation Loss: 0.171\n", - "\n", - " Epoch 5 / 10\n", - " Batch 50 of 122.\n", - " Batch 100 of 122.\n", - "\n", - "Evaluating...\n", - "\n", - "Training Loss: 0.219\n", - "Validation Loss: 0.178\n", - "\n", - " Epoch 6 / 10\n", - " Batch 50 of 122.\n", - " Batch 100 of 122.\n", - "\n", - "Evaluating...\n", - "\n", - "Training Loss: 0.215\n", - "Validation Loss: 0.180\n", - "\n", - " Epoch 7 / 10\n", - " Batch 50 of 122.\n", - " Batch 100 of 122.\n", - "\n", - "Evaluating...\n", - "\n", - "Training Loss: 0.247\n", - "Validation Loss: 0.262\n", - "\n", - " Epoch 8 / 10\n", - " Batch 50 of 122.\n", - " Batch 100 of 122.\n", - "\n", - "Evaluating...\n", - "\n", - "Training Loss: 0.224\n", - "Validation Loss: 0.217\n", - "\n", - " Epoch 9 / 10\n", - " Batch 50 of 122.\n", - " Batch 100 of 122.\n", - "\n", - "Evaluating...\n", - "\n", - "Training Loss: 0.217\n", - "Validation Loss: 0.148\n", - "\n", - " Epoch 10 / 10\n", - " Batch 50 of 122.\n", - " Batch 100 of 122.\n", - "\n", - "Evaluating...\n", - "\n", - "Training Loss: 0.231\n", - "Validation Loss: 0.639\n" - ], - "name": "stdout" - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "_yrhUc9kTI5a", - "colab_type": "text" - }, - "source": [ - "# Load Saved Model" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "OacxUyizS8d1", - "colab_type": "code", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 34 + "b4bef5a685954e238b52c43eefe4c9e5": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_8cf06dad410440f78c50e3527e858905", + "IPY_MODEL_edf0e4c1ae214a66abb717b20e3bffad" + ], + "layout": "IPY_MODEL_94264f36ceb64d3881fed952bf579072" + } + }, + "b6423c858927455e8cbc5a953273466a": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "bdfd7634b8bf42aa8794ada8d9e47173": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } }, - "outputId": "c8b951c2-1f74-4a13-db65-8acd077995e5" - }, - "source": [ - "#load weights of best model\n", - "path = 'saved_weights.pt'\n", - "model.load_state_dict(torch.load(path))" - ], - "execution_count": 24, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "" - ] - }, - "metadata": { - "tags": [] - }, - "execution_count": 24 - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "x4SVftkkTZXA", - "colab_type": "text" - }, - "source": [ - "# Get Predictions for Test Data" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "NZl0SZmFTRQA", - "colab_type": "code", - "colab": {} - }, - "source": [ - "# get predictions for test data\n", - "with torch.no_grad():\n", - " preds = model(test_seq.to(device), test_mask.to(device))\n", - " preds = preds.detach().cpu().numpy()" - ], - "execution_count": 25, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "Ms1ObHZxTYSI", - "colab_type": "code", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 170 + "cb5f7a2a5bcb47649703cc633f2fb685": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } }, - "outputId": "47d01595-e519-4a58-8f2e-75596ea1512d" - }, - "source": [ - "# model's performance\n", - "preds = np.argmax(preds, axis = 1)\n", - "print(classification_report(test_y, preds))" - ], - "execution_count": 26, - "outputs": [ - { - "output_type": "stream", - "text": [ - " precision recall f1-score support\n", - "\n", - " 0 0.99 0.98 0.98 724\n", - " 1 0.88 0.92 0.90 112\n", - "\n", - " accuracy 0.97 836\n", - " macro avg 0.93 0.95 0.94 836\n", - "weighted avg 0.97 0.97 0.97 836\n", - "\n" - ], - "name": "stdout" - } - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "YqzLS7rHTp4T", - "colab_type": "code", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 142 + "ccf5f7e5cc10493ca9c44b14fdec31dc": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } }, - "outputId": "d3abc432-5ad0-41e5-cfc2-d1d1192f5672" - }, - "source": [ - "# confusion matrix\n", - "pd.crosstab(test_y, preds)" - ], - "execution_count": 27, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
col_001
row_0
071014
19103
\n", - "
" - ], - "text/plain": [ - "col_0 0 1\n", - "row_0 \n", - "0 710 14\n", - "1 9 103" - ] - }, - "metadata": { - "tags": [] - }, - "execution_count": 27 + "ed9f97f9d12a49aa939b7595ad3cb27c": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "edf0e4c1ae214a66abb717b20e3bffad": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_b6423c858927455e8cbc5a953273466a", + "placeholder": "​", + "style": "IPY_MODEL_58cd585c531444a5b8613c2d85bab022", + "value": " 232k/232k [00:39<00:00, 5.82kB/s]" + } } - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "jpX1uTwjUPY6", - "colab_type": "code", - "colab": {} - }, - "source": [ - "" - ], - "execution_count": null, - "outputs": [] + } } - ] -} \ No newline at end of file + }, + "nbformat": 4, + "nbformat_minor": 0 +}