diff --git a/experimentation/Diabetes Ridge Regression Training.ipynb b/experimentation/Diabetes Ridge Regression Training.ipynb index fa192115..56e25ff3 100644 --- a/experimentation/Diabetes Ridge Regression Training.ipynb +++ b/experimentation/Diabetes Ridge Regression Training.ipynb @@ -50,240 +50,43 @@ ] }, { - "cell_type": "code", - "execution_count": 7, + "cell_type": "markdown", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "(442, 10)\n" - ] - } - ], "source": [ - "print(df.shape)" + "## Split Data into Training and Validation Sets" ] }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 10, "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "<div>\n", - "<style scoped>\n", - " .dataframe tbody tr th:only-of-type {\n", - " vertical-align: middle;\n", - " }\n", - "\n", - " .dataframe tbody tr th {\n", - " vertical-align: top;\n", - " }\n", - "\n", - " .dataframe thead th {\n", - " text-align: right;\n", - " }\n", - "</style>\n", - "<table border=\"1\" class=\"dataframe\">\n", - " <thead>\n", - " <tr style=\"text-align: right;\">\n", - " <th></th>\n", - " <th>age</th>\n", - " <th>sex</th>\n", - " <th>bmi</th>\n", - " <th>bp</th>\n", - " <th>s1</th>\n", - " <th>s2</th>\n", - " <th>s3</th>\n", - " <th>s4</th>\n", - " <th>s5</th>\n", - " <th>s6</th>\n", - " <th>Y</th>\n", - " </tr>\n", - " </thead>\n", - " <tbody>\n", - " <tr>\n", - " <td>count</td>\n", - " <td>4.420000e+02</td>\n", - " <td>4.420000e+02</td>\n", - " <td>4.420000e+02</td>\n", - " <td>4.420000e+02</td>\n", - " <td>4.420000e+02</td>\n", - " <td>4.420000e+02</td>\n", - " <td>4.420000e+02</td>\n", - " <td>4.420000e+02</td>\n", - " <td>4.420000e+02</td>\n", - " <td>4.420000e+02</td>\n", - " <td>442.000000</td>\n", - " </tr>\n", - " <tr>\n", - " <td>mean</td>\n", - " <td>-3.634285e-16</td>\n", - " <td>1.308343e-16</td>\n", - " <td>-8.045349e-16</td>\n", - " <td>1.281655e-16</td>\n", - " <td>-8.835316e-17</td>\n", - " <td>1.327024e-16</td>\n", - " <td>-4.574646e-16</td>\n", - " <td>3.777301e-16</td>\n", - " <td>-3.830854e-16</td>\n", - " <td>-3.412882e-16</td>\n", - " <td>152.133484</td>\n", - " </tr>\n", - " <tr>\n", - " <td>std</td>\n", - " <td>4.761905e-02</td>\n", - " <td>4.761905e-02</td>\n", - " <td>4.761905e-02</td>\n", - " <td>4.761905e-02</td>\n", - " <td>4.761905e-02</td>\n", - " <td>4.761905e-02</td>\n", - " <td>4.761905e-02</td>\n", - " <td>4.761905e-02</td>\n", - " <td>4.761905e-02</td>\n", - " <td>4.761905e-02</td>\n", - " <td>77.093005</td>\n", - " </tr>\n", - " <tr>\n", - " <td>min</td>\n", - " <td>-1.072256e-01</td>\n", - " <td>-4.464164e-02</td>\n", - " <td>-9.027530e-02</td>\n", - " <td>-1.123996e-01</td>\n", - " <td>-1.267807e-01</td>\n", - " <td>-1.156131e-01</td>\n", - " <td>-1.023071e-01</td>\n", - " <td>-7.639450e-02</td>\n", - " <td>-1.260974e-01</td>\n", - " <td>-1.377672e-01</td>\n", - " <td>25.000000</td>\n", - " </tr>\n", - " <tr>\n", - " <td>25%</td>\n", - " <td>-3.729927e-02</td>\n", - " <td>-4.464164e-02</td>\n", - " <td>-3.422907e-02</td>\n", - " <td>-3.665645e-02</td>\n", - " <td>-3.424784e-02</td>\n", - " <td>-3.035840e-02</td>\n", - " <td>-3.511716e-02</td>\n", - " <td>-3.949338e-02</td>\n", - " <td>-3.324879e-02</td>\n", - " <td>-3.317903e-02</td>\n", - " <td>87.000000</td>\n", - " </tr>\n", - " <tr>\n", - " <td>50%</td>\n", - " <td>5.383060e-03</td>\n", - " <td>-4.464164e-02</td>\n", - " <td>-7.283766e-03</td>\n", - " <td>-5.670611e-03</td>\n", - " <td>-4.320866e-03</td>\n", - " <td>-3.819065e-03</td>\n", - " <td>-6.584468e-03</td>\n", - " <td>-2.592262e-03</td>\n", - " <td>-1.947634e-03</td>\n", - " <td>-1.077698e-03</td>\n", - " <td>140.500000</td>\n", - " </tr>\n", - " <tr>\n", - " <td>75%</td>\n", - " <td>3.807591e-02</td>\n", - " <td>5.068012e-02</td>\n", - " <td>3.124802e-02</td>\n", - " <td>3.564384e-02</td>\n", - " <td>2.835801e-02</td>\n", - " <td>2.984439e-02</td>\n", - " <td>2.931150e-02</td>\n", - " <td>3.430886e-02</td>\n", - " <td>3.243323e-02</td>\n", - " <td>2.791705e-02</td>\n", - " <td>211.500000</td>\n", - " </tr>\n", - " <tr>\n", - " <td>max</td>\n", - " <td>1.107267e-01</td>\n", - " <td>5.068012e-02</td>\n", - " <td>1.705552e-01</td>\n", - " <td>1.320442e-01</td>\n", - " <td>1.539137e-01</td>\n", - " <td>1.987880e-01</td>\n", - " <td>1.811791e-01</td>\n", - " <td>1.852344e-01</td>\n", - " <td>1.335990e-01</td>\n", - " <td>1.356118e-01</td>\n", - " <td>346.000000</td>\n", - " </tr>\n", - " </tbody>\n", - "</table>\n", - "</div>" - ], - "text/plain": [ - " age sex bmi bp s1 \\\n", - "count 4.420000e+02 4.420000e+02 4.420000e+02 4.420000e+02 4.420000e+02 \n", - "mean -3.634285e-16 1.308343e-16 -8.045349e-16 1.281655e-16 -8.835316e-17 \n", - "std 4.761905e-02 4.761905e-02 4.761905e-02 4.761905e-02 4.761905e-02 \n", - "min -1.072256e-01 -4.464164e-02 -9.027530e-02 -1.123996e-01 -1.267807e-01 \n", - "25% -3.729927e-02 -4.464164e-02 -3.422907e-02 -3.665645e-02 -3.424784e-02 \n", - "50% 5.383060e-03 -4.464164e-02 -7.283766e-03 -5.670611e-03 -4.320866e-03 \n", - "75% 3.807591e-02 5.068012e-02 3.124802e-02 3.564384e-02 2.835801e-02 \n", - "max 1.107267e-01 5.068012e-02 1.705552e-01 1.320442e-01 1.539137e-01 \n", - "\n", - " s2 s3 s4 s5 s6 \\\n", - "count 4.420000e+02 4.420000e+02 4.420000e+02 4.420000e+02 4.420000e+02 \n", - "mean 1.327024e-16 -4.574646e-16 3.777301e-16 -3.830854e-16 -3.412882e-16 \n", - "std 4.761905e-02 4.761905e-02 4.761905e-02 4.761905e-02 4.761905e-02 \n", - "min -1.156131e-01 -1.023071e-01 -7.639450e-02 -1.260974e-01 -1.377672e-01 \n", - "25% -3.035840e-02 -3.511716e-02 -3.949338e-02 -3.324879e-02 -3.317903e-02 \n", - "50% -3.819065e-03 -6.584468e-03 -2.592262e-03 -1.947634e-03 -1.077698e-03 \n", - "75% 2.984439e-02 2.931150e-02 3.430886e-02 3.243323e-02 2.791705e-02 \n", - "max 1.987880e-01 1.811791e-01 1.852344e-01 1.335990e-01 1.356118e-01 \n", - "\n", - " Y \n", - "count 442.000000 \n", - "mean 152.133484 \n", - "std 77.093005 \n", - "min 25.000000 \n", - "25% 87.000000 \n", - "50% 140.500000 \n", - "75% 211.500000 \n", - "max 346.000000 " - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ - "# All data in a single dataframe\n", - "df.describe()" + "def = split_data(df):" ] - }, - { + }, + { "cell_type": "markdown", "metadata": {}, "source": [ - "## Split Data into Training and Validation Sets" + "## Split the dataframe into test and train data" ] }, - { + { "cell_type": "code", - "execution_count": 12, + "execution_count": 10, "metadata": {}, "outputs": [], "source": [ + "def = split_data(df):\n", "X = df.drop('Y', axis=1).values\n", "y = df['Y'].values\n", "\n", "X_train, X_test, y_train, y_test = train_test_split(\n", " X, y, test_size=0.2, random_state=0)\n", "data = {\"train\": {\"X\": X_train, \"y\": y_train},\n", - " \"test\": {\"X\": X_test, \"y\": y_test}}" + " \"test\": {\"X\": X_test, \"y\": y_test}}"\n, + " return data" ] }, { @@ -310,7 +113,7 @@ "output_type": "execute_result" } ], - "source": [ + "source": [ "# experiment parameters\n", "args = {\n", " \"alpha\": 0.5\n",