add random seed parameter for fitting decision trees
qualiaMachine authored Dec 5, 2024
1 parent 4d51000 commit 268d20b
Showing 1 changed file with 10 additions and 8 deletions.
18 changes: 10 additions & 8 deletions _episodes/03-classification.md
@@ -127,10 +127,12 @@ We'll first apply a decision tree classifier to the data. Decision trees are co


Training and using a decision tree in Scikit-Learn is straightforward:

**Note**: Decision trees sometimes use randomness when selecting features to split on, for example when two candidate splits have equal information gain, or in ensemble methods (like Random Forests) that select random feature subsets. Setting `random_state` ensures that this randomness is reproducible.
~~~
from sklearn.tree import DecisionTreeClassifier, plot_tree
-clf = DecisionTreeClassifier(max_depth=2)
+clf = DecisionTreeClassifier(max_depth=2, random_state=0)
clf.fit(X_train, y_train)
clf.predict(X_test)
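# Quick reproducibility check (clf_repeat is just an illustrative name, not part
# of the lesson): with the same data and the same random_state, a refitted tree
# is identical, so its predictions match exactly.
clf_repeat = DecisionTreeClassifier(max_depth=2, random_state=0)
clf_repeat.fit(X_train, y_train)
print((clf.predict(X_test) == clf_repeat.predict(X_test)).all())  # prints True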
@@ -179,7 +181,7 @@ from sklearn.inspection import DecisionBoundaryDisplay
f1 = feature_names[0]
f2 = feature_names[3]
-clf = DecisionTreeClassifier(max_depth=2)
+clf = DecisionTreeClassifier(max_depth=2, random_state=0)
clf.fit(X_train[[f1, f2]], y_train)
d = DecisionBoundaryDisplay.from_estimator(clf, X_train[[f1, f2]])
@@ -204,29 +206,30 @@ max_depths = [1, 2, 3, 4, 5]
accuracy = []
for i, d in enumerate(max_depths):
-    clf = DecisionTreeClassifier(max_depth=d)
+    clf = DecisionTreeClassifier(max_depth=d, random_state=0)
    clf.fit(X_train, y_train)
    acc = clf.score(X_test, y_test)
    accuracy.append((d, acc))
acc_df = pd.DataFrame(accuracy, columns=['depth', 'accuracy'])
-sns.lineplot(acc_df, x='depth', y='accuracy')
+sns.lineplot(acc_df, x='depth', y='accuracy', marker='o')
plt.xlabel('Tree depth')
plt.ylabel('Accuracy')
plt.show()
~~~
{: .language-python}


![Performance of decision trees of various depths](../fig/e3_dt_overfit.png)

Here we can see that a tree with `max_depth=2` performs slightly better on the test data than the trees with `max_depth > 2`. This can seem counterintuitive: surely more questions should split up our categories better and thus give better predictions?
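
One way to see what is happening is to record the training accuracy alongside the test accuracy for each depth. The sketch below reuses the loop above (the long-format reshaping with `melt` and the names `depth_acc` and `depth_acc_long` are added details, not part of the original lesson); as the tree gets deeper, training accuracy should keep climbing while test accuracy stalls or drops.

~~~
depth_acc = []
for d in max_depths:
    clf = DecisionTreeClassifier(max_depth=d, random_state=0)
    clf.fit(X_train, y_train)
    # score on both the data the tree was fit to and the held-out test data
    depth_acc.append((d, clf.score(X_train, y_train), clf.score(X_test, y_test)))

depth_acc_long = pd.DataFrame(depth_acc, columns=['depth', 'train', 'test']).melt(
    id_vars='depth', var_name='split', value_name='accuracy')
sns.lineplot(depth_acc_long, x='depth', y='accuracy', hue='split', marker='o')
plt.xlabel('Tree depth')
plt.ylabel('Accuracy')
plt.show()
~~~
{: .language-python}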

Let's reuse our fitting and plotting code from above to inspect a decision tree that has `max_depth=5`:

~~~
-clf = DecisionTreeClassifier(max_depth=5)
+clf = DecisionTreeClassifier(max_depth=5, random_state=0)
clf.fit(X_train, y_train)
fig = plt.figure(figsize=(12, 10))
Expand All @@ -242,7 +245,7 @@ It looks like our decision tree has split up the training data into the correct
f1 = feature_names[0]
f2 = feature_names[3]
-clf = DecisionTreeClassifier(max_depth=5)
+clf = DecisionTreeClassifier(max_depth=5, random_state=0)
clf.fit(X_train[[f1, f2]], y_train)
d = DecisionBoundaryDisplay.from_estimator(clf, X_train[[f1, f2]])
@@ -258,7 +261,6 @@ Earlier we saw that the `max_depth=2` model split the data into 3 simple boundin

This is a classic case of overfitting: our model has learned extremely specific rules that fit the training data but are not representative of our test data. Sometimes simplicity is better!
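
Rather than relying on a single train/test split to choose `max_depth`, a more robust approach is cross-validation. This is not part of the original lesson, but a minimal sketch using scikit-learn's `cross_val_score` might look like this:

~~~
from sklearn.model_selection import cross_val_score

for d in max_depths:
    clf = DecisionTreeClassifier(max_depth=d, random_state=0)
    # 5-fold cross-validation on the training set: fit and score 5 times,
    # each time holding out a different fifth of the data
    scores = cross_val_score(clf, X_train, y_train, cv=5)
    print(f"max_depth={d}: mean CV accuracy = {scores.mean():.3f}")
~~~
{: .language-python}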


## Classification using support vector machines
Next, we'll look at another commonly used classification algorithm, and see how it compares. Support Vector Machines (SVM) work in a way that is conceptually similar to your own intuition when first looking at the data. They devise a set of hyperplanes that delineate the parameter space, such that each region contains ideally only observations from one class, and the boundaries fall between classes.
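
As a minimal sketch of what fitting an SVM classifier looks like in scikit-learn (using the default `SVC`, i.e. an RBF kernel, which is an assumption rather than necessarily the lesson's exact setup):

~~~
from sklearn.svm import SVC

svc = SVC()  # default RBF kernel; other kernels such as 'linear' or 'poly' are possible
svc.fit(X_train, y_train)
print(svc.score(X_test, y_test))
~~~
{: .language-python}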

@@ -310,4 +312,4 @@ plt.show()

![Classification space generated by the SVM model](../fig/e3_svc_space.png)

While this SVM model performs slightly worse than our decision tree (95.6% vs. 98.5%), it's likely that the non-linear boundaries will perform better when exposed to more and more real data, as decision trees are prone to overfitting and require complex combinations of linear splits to reproduce simple non-linear boundaries. It's important to pick a model that is appropriate for your problem and data trends!
