From dc6d256edb26a61a34a2091638629e9219398d0c Mon Sep 17 00:00:00 2001 From: Banias Baabe Date: Sun, 24 Dec 2023 13:11:59 +0100 Subject: [PATCH] Add skorch --- book/machinelearning/modeltraining.ipynb | 81 ++++++++++++++++++++++++ 1 file changed, 81 insertions(+) diff --git a/book/machinelearning/modeltraining.ipynb b/book/machinelearning/modeltraining.ipynb index 21bc867..bbf373f 100644 --- a/book/machinelearning/modeltraining.ipynb +++ b/book/machinelearning/modeltraining.ipynb @@ -1644,6 +1644,87 @@ "results = tuner.fit()\n", "print(results.get_best_result(metric=\"mean_accuracy\", mode=\"max\").config)" ] + }, + { + "cell_type": "markdown", + "id": "62334c73", + "metadata": {}, + "source": [ + "## Use PyTorch with scikit-learn API with `skorch`" + ] + }, + { + "cell_type": "markdown", + "id": "4ac45c78", + "metadata": {}, + "source": [ + "PyTorch and scikit-learn are one of the most popular libraries for ML/DL.\n", + "\n", + "So, why not combine PyTorch with scikit-learn?\n", + "\n", + "Try `skorch`!\n", + "\n", + "`skorch` is a high-level library for PyTorch that provides a scikit-learn-compatible neural network module.\n", + "\n", + "It allows you to use the simple scikit-learn interface for PyTorch.\n", + "\n", + "Therefore you can integrate PyTorch models into scikit-learn workflows.\n", + "\n", + "See below for an example." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "09a00fe5", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install skorch" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ad9ed770", + "metadata": {}, + "outputs": [], + "source": [ + "from torch import nn\n", + "from skorch import NeuralNetClassifier\n", + "from sklearn.pipeline import Pipeline\n", + "from sklearn.preprocessing import StandardScaler\n", + "\n", + "class MyModule(nn.Module):\n", + " def __init__(self, num_units=10, nonlin=nn.ReLU()):\n", + " super().__init__()\n", + "\n", + " self.dense = nn.Linear(20, num_units)\n", + " self.nonlin = nonlin\n", + " self.output = nn.Linear(num_units, 2)\n", + " self.softmax = nn.Softmax(dim=-1)\n", + "\n", + " def forward(self, X, **kwargs):\n", + " X = self.nonlin(self.dense(X))\n", + " X = self.dropout(X)\n", + " X = self.softmax(self.output(X))\n", + " return X\n", + "\n", + "net = NeuralNetClassifier(\n", + " MyModule,\n", + " max_epochs=10,\n", + " lr=0.1,\n", + " iterator_train__shuffle=True,\n", + ")\n", + "\n", + "pipe = Pipeline([\n", + " ('scale', StandardScaler()),\n", + " ('net', net),\n", + "])\n", + "\n", + "pipe.fit(X, y)\n", + "y_proba = pipe.predict_proba(X)" + ] } ], "metadata": {