Part 1: Decision Trees
Decision trees and random forests are two commonly used algorithms in predictive modeling. In this article, I’m going to discuss the process behind decision trees. I plan to follow this up with a second part that discusses random forests and then compares the two.
First off: decision trees. A decision tree is named for the shape of the plot that comes out. The image below shows a decision tree for deciding what factors affected survival from the Titanic disaster.
Before we go any further, I should introduce some terminology. Each branching point in the tree is called a node and represents a dataset that contains some or all of the records of the starting dataset. Going with the “tree” motif, the top (or starting) node is also called the root node. This contains all of the records (rows, individuals, whatever you want to call them) of the dataset (or at least all the records that you want to include). The tree grows from this root, giving us more nodes until it generates terminal nodes (i.e., those that aren’t split). These terminal nodes are called leaves. The tree above has four leaves. Each is labeled with the final prediction for that node.
There are two types of decision trees: classification and regression. A classification tree predicts the category of a categoric dependent variable — yes/no, apple/orange, died/survived, etc. A regression tree predicts the value of a numeric variable, similar to linear regression. The thing to watch out for with regression trees is that they cannot extrapolate outside the range of the training dataset the way linear regression can. However, regression trees can use categoric input variables directly, unlike linear regression.
While the Titanic decision tree shows binary splits (each non-leaf node produces two child nodes), this is not a general requirement. Depending on the decision tree, nodes may have three or even more child nodes. I’m going to focus on classification decision trees for the rest of this article, but the basic idea is the same for regression trees as for classification trees.
Finally, I’ll mention that this discussion assumes the use of the rpart() function in R. I’ve heard that Python’s scikit-learn can’t handle categoric variables directly (they have to be numerically encoded first), but I’m much less familiar with Python, especially for data analysis. I believe that the basic theory is the same, but the implementation is different.
Decision trees are created in an iterative fashion. First, the variables are scanned to determine which gives the best split (more on this in a bit), then the dataset is split into smaller subsets based on that determination. Each subset is then analyzed again, with new subsets created, until the algorithm decides to stop. This decision is partly controlled by the parameters you set for the algorithm.
The splitting is based on how good the prediction is for the dataset (or subset) in question. The Titanic decision tree above initially created two subsets, based on gender, that have better predictive value. If we look at the Titanic example, icyousee.org reports that the overall survival rate was 32%. Looking at the decision tree above, we see that 73% of females survived. Summarizing the tree overall, we can generate the following rules:
- If a passenger was female, she most likely survived (73% chance)
- If a passenger was male, then survival depended on age and number of siblings aboard.
- Young boys with few siblings were very likely to survive (89% chance)
- Most males were out of luck
From here on out, I’m going to use an artificial dataset that I created, recording heights and weights for 10 adult men and 10 adult women. (Heights were randomly generated based on actual means and percentile rankings from a CDC study — see below. Weights were scaled from each height, with random deviation based on the weight standard deviation. See R code at the end.)
Let’s see if we can use this dataset to predict whether an individual is male or female. If we choose a person at random, we have a 50/50 chance of getting a male or female. With a decent predictive model, we should be able to do better. Males tend to be taller and weigh more, so maybe we can use one or both of those variables.
Decision trees use something called entropy to tell us how certain our prediction is. In physical terms, entropy refers to the amount of disorder in a system. The same is true here. If entropy is zero, there is no disorder. This only happens if we are absolutely certain that we know what we will get when we pick someone from our dataset. Don’t expect this to happen in real life — ever. We do, however, want to get the entropy as low as possible. The entropy of a dataset with two classes present is calculated by:

S = -(p1 × log2(p1)) - (p2 × log2(p2)) (Equation 1)

where:
- S = entropy of dataset
- p1 = probability that an individual belongs to class 1
- p2 = probability that an individual belongs to class 2
With the original dataset, S = 1, which is the maximum disorder achievable with this equation. Try it out for yourself. Make sure you’re using log base 2. It’s not really vital to use base 2, but it’s the only way to get 1 as your maximum. (Note: if a particular dataset contains only one class, the second term is not present. If a dependent variable has more than two possible values, extra terms are added.)
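As a quick sanity check, here’s a small R helper of my own (rpart does this internally, so this is purely for illustration) that computes the entropy of a node from its class counts:

```r
# Entropy of a node, given the count of records in each class.
# An illustrative helper, not part of rpart.
entropy <- function(counts) {
  p <- counts / sum(counts)
  p <- p[p > 0]            # drop absent classes: 0 * log2(0) is taken as 0
  -sum(p * log2(p))
}

entropy(c(10, 10))  # the original 50/50 dataset: S = 1, maximum disorder
entropy(c(20, 0))   # a pure, one-class node: S = 0
```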
When a dataset is split in a decision tree, the total entropy is calculated with the weighted average of the entropies of the subsets, as follows:

S(T) = f(1) × S(1) + f(2) × S(2) + … (Equation 2)

where (sorry, but Medium doesn’t allow subscripts in the text):
- S(T) = total entropy after split
- S(x) = entropy of subset x (i.e., S(1) = entropy of subset 1)
- f(x) = fraction of individuals that go to subset x
The total entropy at any point is calculated from all current leaf nodes (even if they may split later). In the Titanic tree from earlier, for example, the total entropy would have two (f S) terms after the first (gender) split, three terms after the second (age) split, and four terms after the third (sibling) split.
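In R, Equation 2 might be sketched like this (again, helpers of my own for illustration, with the class counts of each current leaf passed in as a list):

```r
# Entropy of a single node (Equation 1); illustrative helper
entropy <- function(counts) {
  p <- counts[counts > 0] / sum(counts)
  -sum(p * log2(p))
}

# Total entropy across all current leaf nodes (Equation 2):
# each leaf contributes f(x) * S(x), where f(x) is its share of the records
total_entropy <- function(leaves) {
  n <- sum(unlist(leaves))
  sum(sapply(leaves, function(cnt) sum(cnt) / n * entropy(cnt)))
}

total_entropy(list(c(10, 10)))            # one unsplit 50/50 node: S(T) = 1
total_entropy(list(c(10, 0), c(0, 10)))   # two pure leaves: S(T) = 0
```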
A decision tree algorithm will look at different values of different variables to determine which gives the best split, based on the reduction of entropy from the original value.
Let’s go through this with the height/weight dataset. Let’s start with weight as the dividing point. The mean female weight is 177 pounds, while the mean male weight is 201 pounds. Let’s split at the midpoint between those (189 pounds). This will give us two subsets:
- Subset 1 (< 189 lbs, predicted female) has 6 females and 3 males (9 total)
- Subset 2 (≥ 189 lbs, predicted male) has 4 females and 7 males (11 total)
From this info, we can calculate the entropy of the two subsets:

S(1) = -(6/9 × log2(6/9)) - (3/9 × log2(3/9)) ≈ 0.918

S(2) = -(4/11 × log2(4/11)) - (7/11 × log2(7/11)) ≈ 0.946

I’ll leave it to you, dear reader, to confirm the second entropy. To calculate the overall entropy of the split, we use Equation 2:

S(T) = (9/20 × 0.918) + (11/20 × 0.946) ≈ 0.93
So by splitting on weight at 189 pounds, we get a slight reduction in entropy. A decision tree algorithm will check many values to see if they might give better results. It turns out that if we drop the split value down to 186 pounds, the total entropy drops to 0.88.
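Here’s that arithmetic in R, using a small entropy helper of my own (repeated here so the snippet stands alone):

```r
# Entropy of a node from its class counts (Equation 1); illustrative helper
entropy <- function(counts) {
  p <- counts[counts > 0] / sum(counts)
  -sum(p * log2(p))
}

s1 <- entropy(c(6, 3))                  # subset 1: 6 females, 3 males
s2 <- entropy(c(4, 7))                  # subset 2: 4 females, 7 males
total <- (9/20) * s1 + (11/20) * s2     # weighted by subset size (Equation 2)
round(c(s1, s2, total), 2)              # roughly 0.92, 0.95, 0.93
```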
Of course, we can also try splitting based on height. Let’s start by going with the midpoint of the means again. That would give us a split value of 66.2 inches. From this split, we get the following two subsets:
- Subset 1 (< 66.2 in, predicted female) has 9 females and 2 males (11 total)
- Subset 2 (≥ 66.2 in, predicted male) has 1 female and 8 males (9 total)
With this information, we can calculate the entropy for this split:

S(1) = -(9/11 × log2(9/11)) - (2/11 × log2(2/11)) ≈ 0.684

S(2) = -(1/9 × log2(1/9)) - (8/9 × log2(8/9)) ≈ 0.503

S(T) = (11/20 × 0.684) + (9/20 × 0.503) ≈ 0.60
This is much better than using weight for the split. But if we raise the split point to 68.2 inches, we can do even better, with a total entropy of 0.51. So now we have the first split for our decision tree. Next, the algorithm will check each of these subsets to see if they can be split again, based on weight this time. Doing this may get us better entropies, but the more we split, the more we risk overfitting the data. Remember that I’ve only sampled 20 individuals here, and some of those look like outliers (e.g., the tallest female weighs the least). With large datasets with many variables, you can get a true mess of a decision tree. There are several parameters that we can use to limit this, but that’s beyond the scope of what I wanted to go into here.
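Before moving on, here’s the height-split check in R (with the same kind of illustrative entropy helper, repeated so the snippet stands alone):

```r
# Entropy of a node from its class counts (Equation 1); illustrative helper
entropy <- function(counts) {
  p <- counts[counts > 0] / sum(counts)
  -sum(p * log2(p))
}

# split on height at 66.2 inches
s1 <- entropy(c(9, 2))                  # subset 1: 9 females, 2 males
s2 <- entropy(c(1, 8))                  # subset 2: 1 female, 8 males
total <- (11/20) * s1 + (9/20) * s2     # ~0.60, versus ~0.93 for the weight split
```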
In a previous post, I talked about using Rattle to help in learning R. It turns out that the rattle package has a nice plotting function for decision trees. Here’s the image I get when I run the dataset through Rattle’s decision tree algorithm, with standard parameters (notice that it stopped after only one split):
This plot has several features worth walking through:
- The color corresponds to the predicted gender (green = female, blue = male).
- Node 1 (aka the “root node”) has a 50/50 split of genders and accounts for 100% of the observations. Its green color means that if it had to predict the gender from this group, it would choose female.
- Below this node, we see the first split. If height < 68 inches, then we predict female. Otherwise, we predict male.
- Node 2 accounts for 65% of the dataset and has a 77/23 split of females to males. Its green color means that everybody in this subset is predicted to be female.
- Node 3 accounts for 35% of the dataset and is 100% male, which is why it’s blue.
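If you want to reproduce something like this yourself, the fit comes from rpart and the fancy plot from the rattle package. Here’s a minimal sketch; note that the set.seed() call is my addition for reproducibility (the original data were generated without one, so your tree will differ from the one above), and the data generation is condensed from the full code at the end of this article:

```r
library(rpart)

set.seed(1)  # my addition so the sketch is reproducible; not in the original
FemHt  <- round(rnorm(10, mean = 63.5, sd = 2.8), 1)
MaleHt <- round(rnorm(10, mean = 69.0, sd = 3.0), 1)
df <- rbind(
  data.frame(Height = FemHt,
             Weight = round(FemHt / 63.5 * 170.8 + rnorm(10, sd = 22), 0),
             Gender = "F"),
  data.frame(Height = MaleHt,
             Weight = round(MaleHt / 69.0 * 199.8 + rnorm(10, sd = 22), 0),
             Gender = "M")
)
df$Gender <- factor(df$Gender)  # make the outcome an explicit factor

fit <- rpart(Gender ~ Height + Weight, data = df, method = "class")
print(fit)                      # text version of the tree
# rattle::fancyRpartPlot(fit)   # draws the annotated plot shown above
```

With only 20 rows and rpart’s default minsplit of 20, the root node is the only one eligible to split, which is consistent with the tree above stopping after one split.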
The biggest benefit of decision trees is the ease of understanding. They’re easy to read and it’s easy to see how the model made its prediction.
In part 2, we will investigate random forests, and see how they compare to decision trees.
- Titanic: Demographics of the Passengers
- Data Mining with Rattle and R: The Art of Excavating Data for Knowledge Discovery (Use R!)
(Note that the Amazon link above is an affiliate link.)
Creating the dataset
# Data from https://www.cdc.gov/nchs/data/series/sr_03/sr03-046-508.pdf
# means are directly retrieved from report
# SDs are estimated from 15th and 85th percentiles
# Weight of Females over 20 - Table 4 - excludes pregnant females
FWnum <- 5386 # number of females in sample
FWmean <- 170.8 # mean weight of females, in pounds
#15% = 126.9
#85% = 216.4
#diff / 2 = 44.75
FWSD <- 44 # estimated std dev, in pounds
# Weight of Males over 20 - Table 6
MWnum <- 5085 # number of males in sample
MWmean <- 199.8 # mean weight of males, in pounds
#15% = 154.2
#85% = 243.8
#diff / 2 = 44.8
MWSD <- 44 # estimated std dev, in pounds
# Height of Females over 20 - Table 10
FHnum <- 5510 # number of females in sample
FHmean <- 63.5 # mean height of females over 20, in inches
#15% = 60.6
#85% = 66.3
#diff / 2 = 2.85
FHSD <- 2.8 # estimated std dev, in inches
# Height of Males over 20 - Table 12
MHnum <- 5092 # number of males in sample
MHmean <- 69.0 # mean height of males over 20, in inches
#15% = 66.0
#85% = 72.0
#diff / 2 = 3.0
MHSD <- 3 # estimated std dev, in inches
# create 10 normally distributed female heights
FemaleHeight <- round(rnorm(10, mean = FHmean, sd = FHSD), 1)
# Calculate weight based on comparison of height to mean height
FemWCorrel <- FemaleHeight/FHmean * FWmean
# throw in some random deviation based on weight SD
FemWAdj <- rnorm(10, sd = FWSD/2)
FemaleWeight <- round(FemWCorrel + FemWAdj, 0)
# note: "F" shadows R's shorthand for FALSE; harmless here, but risky in longer scripts
F <- data.frame(Height = FemaleHeight,
                Weight = FemaleWeight,
                Gender = "F")
# create 10 normally distributed male heights
MaleHeight <- round(rnorm(10, mean = MHmean, sd = MHSD), 1)
# Calculate weight based on comparison of height to mean height
MaleWCorrel <- MaleHeight/MHmean * MWmean
# throw in some random deviation based on weight SD
MaleWAdj <- rnorm(10, sd = MWSD/2)
MaleWeight <- round((MaleWCorrel + MaleWAdj), 0)
M <- data.frame(Height = MaleHeight,
                Weight = MaleWeight,
                Gender = "M")
df <- rbind(F, M)