Skip to content

Commit

Permalink
Add skorch
Browse files Browse the repository at this point in the history
  • Loading branch information
baniasbaabe committed Dec 24, 2023
1 parent 56997dd commit dc6d256
Showing 1 changed file with 81 additions and 0 deletions.
81 changes: 81 additions & 0 deletions book/machinelearning/modeltraining.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1644,6 +1644,87 @@
"results = tuner.fit()\n",
"print(results.get_best_result(metric=\"mean_accuracy\", mode=\"max\").config)"
]
},
{
"cell_type": "markdown",
"id": "62334c73",
"metadata": {},
"source": [
"## Use PyTorch with scikit-learn API with `skorch`"
]
},
{
"cell_type": "markdown",
"id": "4ac45c78",
"metadata": {},
"source": [
"PyTorch and scikit-learn are one of the most popular libraries for ML/DL.\n",
"\n",
"So, why not combine PyTorch with scikit-learn?\n",
"\n",
"Try `skorch`!\n",
"\n",
"`skorch` is a high-level library for PyTorch that provides a scikit-learn-compatible neural network module.\n",
"\n",
"It allows you to use the simple scikit-learn interface for PyTorch.\n",
"\n",
"Therefore you can integrate PyTorch models into scikit-learn workflows.\n",
"\n",
"See below for an example."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "09a00fe5",
"metadata": {},
"outputs": [],
"source": [
"!pip install skorch"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ad9ed770",
"metadata": {},
"outputs": [],
"source": [
"from torch import nn\n",
"from skorch import NeuralNetClassifier\n",
"from sklearn.pipeline import Pipeline\n",
"from sklearn.preprocessing import StandardScaler\n",
"\n",
"class MyModule(nn.Module):\n",
" def __init__(self, num_units=10, nonlin=nn.ReLU()):\n",
" super().__init__()\n",
"\n",
" self.dense = nn.Linear(20, num_units)\n",
" self.nonlin = nonlin\n",
" self.output = nn.Linear(num_units, 2)\n",
" self.softmax = nn.Softmax(dim=-1)\n",
"\n",
" def forward(self, X, **kwargs):\n",
" X = self.nonlin(self.dense(X))\n",
" X = self.dropout(X)\n",
" X = self.softmax(self.output(X))\n",
" return X\n",
"\n",
"net = NeuralNetClassifier(\n",
" MyModule,\n",
" max_epochs=10,\n",
" lr=0.1,\n",
" iterator_train__shuffle=True,\n",
")\n",
"\n",
"pipe = Pipeline([\n",
" ('scale', StandardScaler()),\n",
" ('net', net),\n",
"])\n",
"\n",
"pipe.fit(X, y)\n",
"y_proba = pipe.predict_proba(X)"
]
}
],
"metadata": {
Expand Down

0 comments on commit dc6d256

Please sign in to comment.