Introduction
Last time we talked about non-linear classifiers using Support Vector Machines (SVM). Today we'll be discussing another model that works both as a non-linear classifier and as a regressor: the decision tree. A decision tree builds a model that predicts the value of a target variable by learning simple decision rules inferred from the data features.
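To make "simple decision rules" concrete, here is a minimal sketch that prints the rules a shallow tree learns on the iris data, using scikit-learn's `export_text` (available since version 0.21). The depth limit of 2 is an arbitrary choice that keeps the printout short.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_text

iris = load_iris()

# A shallow tree keeps the printed rule set short (max_depth=2 is an arbitrary choice)
clf = DecisionTreeClassifier(max_depth=2, random_state=0).fit(iris.data, iris.target)

# export_text renders the learned if/else rules as indented plain text
rules = export_text(clf, feature_names=list(iris.feature_names))
print(rules)
```

Each `|---` line is one test on a single feature, and the `class:` lines are the leaves where a prediction is made.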
Since trees are a structure we're all familiar with and can easily be drawn, decision trees are easy to explain and visualize, and they handle non-linearity in an intuitive way. There are some disadvantages as well, which we'll note a bit later; first, let's see them in action.
Implementation
It won't come as a complete surprise that the scikit-learn package has already taken the initiative and implemented the whole thing in the DecisionTreeRegressor and DecisionTreeClassifier classes. All that is left for us is to reap the fruits of someone else's hard labour.
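The regressor works the same way. As a minimal sketch, assuming a made-up noisy sine curve as training data, a `DecisionTreeRegressor` limited to `max_depth=2` predicts a piecewise-constant function with at most four distinct values, one per leaf:

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

# Hypothetical toy data: a noisy sine curve on [0, 5]
rng = np.random.RandomState(0)
X = np.sort(5 * rng.rand(80, 1), axis=0)
y = np.sin(X).ravel() + 0.1 * rng.randn(80)

# max_depth=2 allows at most 2**2 = 4 leaves
reg = DecisionTreeRegressor(max_depth=2).fit(X, y)

X_grid = np.linspace(0, 5, 200).reshape(-1, 1)
y_pred = reg.predict(X_grid)
print(np.unique(y_pred))  # one constant value per leaf reached
```

With the default squared-error criterion, each leaf predicts the mean target value of the training samples that fall into it, which is why the prediction is a step function.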
```python
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

# Parameters
n_classes = 3
plot_colors = "bry"
plot_step = 0.02
plt.rcParams["figure.figsize"] = [12, 8]

# Load data
iris = load_iris()

for pairidx, pair in enumerate([[0, 1], [0, 2], [0, 3], [1, 2], [1, 3], [2, 3]]):
    # We only take the two corresponding features
    X = iris.data[:, pair]
    y = iris.target

    # Shuffle
    idx = np.arange(X.shape[0])
    np.random.seed(13)
    np.random.shuffle(idx)
    X = X[idx]
    y = y[idx]

    # Standardize (zero mean, unit variance per feature)
    mean = X.mean(axis=0)
    std = X.std(axis=0)
    X = (X - mean) / std

    # Train
    clf = DecisionTreeClassifier().fit(X, y)

    # Plot the decision boundary by predicting on a dense grid
    plt.subplot(2, 3, pairidx + 1)
    x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
    y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
    xx, yy = np.meshgrid(np.arange(x_min, x_max, plot_step),
                         np.arange(y_min, y_max, plot_step))
    Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])
    Z = Z.reshape(xx.shape)
    plt.pcolormesh(xx, yy, Z, cmap=plt.cm.Paired, shading="auto")
    plt.xlabel(iris.feature_names[pair[0]])
    plt.ylabel(iris.feature_names[pair[1]])

    # Plot the training points, one colour per class
    for i, color in zip(range(n_classes), plot_colors):
        mask = y == i
        plt.scatter(X[mask, 0], X[mask, 1], c=color, label=iris.target_names[i])

    plt.axis("tight")

plt.legend(loc="upper left")
plt.show()
```
For each pair of features, we first shuffle and standardize the data, then fit a tree and draw its decision boundaries, and finally plot the training points themselves.
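The standardization step is there for the plots rather than for the tree itself: splits only depend on the ordering of the feature values, so rescaling a feature just moves the thresholds. A quick sanity check, assuming the sepal pair and a fixed `random_state` so that both fits break ties between features the same way:

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
X, y = iris.data[:, [0, 1]], iris.target  # sepal length and sepal width

# Standardize exactly as in the plotting script
X_std = (X - X.mean(axis=0)) / X.std(axis=0)

# Same random_state so both trees visit features in the same order
raw_tree = DecisionTreeClassifier(random_state=0).fit(X, y)
std_tree = DecisionTreeClassifier(random_state=0).fit(X_std, y)

# Same partition of the training points, only the thresholds differ
same = np.array_equal(raw_tree.predict(X), std_tree.predict(X_std))
print(same)
```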
Take a look at the first subplot (sepal length vs. sepal width) and compare it with the one from our logistic regression post. Back then we couldn't separate the classes using these two features; now, with the decision tree, we can.
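One way to put a number on this comparison is training accuracy. The sketch below assumes scikit-learn's `LogisticRegression` with default settings as the linear baseline, which is not necessarily the exact setup from the earlier post, and fits both models to the sepal features:

```python
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
X, y = iris.data[:, [0, 1]], iris.target  # sepal length, sepal width

# max_iter is raised only to avoid convergence warnings; nothing is tuned
log_reg = LogisticRegression(max_iter=1000).fit(X, y)
tree_clf = DecisionTreeClassifier(random_state=0).fit(X, y)

# Accuracy measured on the training data itself
log_acc = log_reg.score(X, y)
tree_acc = tree_clf.score(X, y)
print(f"logistic regression: {log_acc:.3f}, decision tree: {tree_acc:.3f}")
```

Keep in mind this is accuracy on the training data: a fully grown tree keeps splitting until its leaves are pure or hold indistinguishable points, so it scores as high as any model can there, which is exactly the overfitting risk discussed in the conclusion.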
Visualization
Since a decision tree is built on a tree data structure, wouldn't it be nice to visualize it?
Note: you'll need to install the Graphviz package (and the pydot Python package) to run this example.
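```python
# ... continues the script above: reuses clf (the tree from the last loop iteration), iris and pair
from io import StringIO

import pydot
from IPython.display import Image
from sklearn import tree

dot_data = StringIO()
tree.export_graphviz(clf, out_file=dot_data,
                     # name only the two features this tree was trained on
                     feature_names=[iris.feature_names[i] for i in pair],
                     class_names=iris.target_names,
                     filled=True, rounded=True,
                     special_characters=True)
# pydot >= 1.2 returns a list of graphs; this DOT source contains exactly one
(graph,) = pydot.graph_from_dot_data(dot_data.getvalue())
Image(graph.create_png())
```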
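If installing Graphviz is a hassle, scikit-learn (0.21 and later) can also draw the tree with matplotlib alone via `plot_tree`. Here is a minimal sketch that refits the standardized petal-pair tree so it runs on its own; the output file name `iris_tree.png` is just an example:

```python
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, plot_tree

iris = load_iris()
pair = [2, 3]  # petal length and petal width, the last pair in the loop above
X = iris.data[:, pair]
X = (X - X.mean(axis=0)) / X.std(axis=0)  # standardized, as in the plotting script
clf = DecisionTreeClassifier(random_state=0).fit(X, iris.target)

fig, ax = plt.subplots(figsize=(12, 8))
# plot_tree draws every node with matplotlib; no Graphviz binary is needed
annotations = plot_tree(clf,
                        feature_names=[iris.feature_names[i] for i in pair],
                        class_names=list(iris.target_names),
                        filled=True, rounded=True, ax=ax)
fig.savefig("iris_tree.png")  # example output file name
```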
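Now that we have our tree in place, let's see how the decisions were made, with the Gini impurity attached to each node. Gini impurity measures how mixed the classes (in our case, iris species) are within a node: it equals 1 - Σ p_k², where p_k is the fraction of the node's samples belonging to class k. Equivalently, it is the probability of mislabelling a randomly drawn sample if it were labelled at random according to the node's class distribution. A Gini impurity of zero means the node is pure: all irises in it are of the same species.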
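The formula is simple enough to compute by hand. Here is a minimal NumPy sketch; the helper name `gini_impurity` is our own, not a library function:

```python
import numpy as np


def gini_impurity(labels):
    """Gini impurity 1 - sum_k p_k**2 of an array of class labels."""
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return 1.0 - np.sum(p ** 2)


# A root with 50 samples of each of the 3 species
root = np.repeat([0, 1, 2], 50)
print(round(gini_impurity(root), 4))  # 0.6667

# A pure node: only setosa
print(gini_impurity(np.zeros(50)))    # 0.0
```

A node with the three species in equal proportion gives 1 - 3·(1/3)² = 0.6667, the value shown at the top of the tree.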
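The first check is whether the standardized petal length is at most -0.7442 (the visualized clf is the tree trained in the last loop iteration, on petal length and petal width). The root's Gini impurity of 0.6667 = 1 - 3·(1/3)² tells us that the three species are equally represented there, and this single split sends all the setosa samples, one third of the data, into a leaf with Gini impurity 0. The process continues down each branch until a node reaches a Gini impurity of 0, that is, until the remaining data belongs to a single species (or can no longer be split).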
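The fitted tree exposes these numbers through its low-level `tree_` attribute (`impurity`, `n_node_samples`, `children_left`, `children_right`). A minimal sketch, assuming the standardized petal pair as above, that reads off the root split and the weighted impurity decrease it achieves:

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
X = iris.data[:, [2, 3]]  # petal length, petal width
X = (X - X.mean(axis=0)) / X.std(axis=0)
clf = DecisionTreeClassifier(random_state=0).fit(X, iris.target)

t = clf.tree_
n = t.n_node_samples
left, right = t.children_left[0], t.children_right[0]
for name, node in [("root", 0), ("left", left), ("right", right)]:
    print(f"{name:>5}: gini={t.impurity[node]:.4f} samples={n[node]}")

# Parent impurity minus the sample-weighted impurity of the two children
decrease = t.impurity[0] - (n[left] * t.impurity[left] + n[right] * t.impurity[right]) / n[0]
print(f"decrease: {decrease:.4f}")
```

The split lowers the weighted impurity from 0.6667 to 100/150 · 0.5 = 0.3333; at every node the greedy algorithm picks the threshold with the largest such decrease.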
Conclusion
One notable advantage of decision trees is their prediction cost, which (for a reasonably balanced tree) is logarithmic in the number of data points used to train the tree. But there are some downsides that need to be considered as well. The first one can be seen in our first subplot: overfitting. A fully grown tree keeps splitting until its leaves are pure, so while the splits near the top capture the general structure of the data, the splits near the bottom tend to fit noise and individual training points. This is a major downside, and therefore trees should be pruned, for example by limiting their depth or by cost-complexity pruning. Decision trees can also be unstable, because small variations or noise in the data might result in a completely different tree being generated. However, this problem is mitigated by using decision trees within an ensemble, which we'll talk about next time.
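Here is a minimal sketch of both pruning styles in scikit-learn on the sepal pair from the first subplot: pre-pruning with `max_depth`, and minimal cost-complexity post-pruning with `ccp_alpha` (scikit-learn 0.22 and later). The values 3 and 0.01 are illustrative, not tuned; in practice you would pick them by cross-validation, for example over the alphas returned by `cost_complexity_pruning_path`.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
X, y = iris.data[:, [0, 1]], iris.target  # the sepal pair from the first subplot

# Fully grown tree: splits until leaves are pure (or hold indistinguishable points)
full = DecisionTreeClassifier(random_state=0).fit(X, y)

# Pre-pruning: cap the depth (illustrative value)
shallow = DecisionTreeClassifier(max_depth=3, random_state=0).fit(X, y)

# Post-pruning: minimal cost-complexity pruning (illustrative alpha)
pruned = DecisionTreeClassifier(ccp_alpha=0.01, random_state=0).fit(X, y)

for name, model in [("full", full), ("max_depth=3", shallow), ("ccp_alpha=0.01", pruned)]:
    print(f"{name:>15}: {model.get_n_leaves():3d} leaves, depth {model.get_depth()}, "
          f"train accuracy {model.score(X, y):.3f}")
```

The pruned trees have fewer leaves and, necessarily, no higher training accuracy than the full tree; whether they generalize better has to be checked on held-out data.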