by Joseph Rickert
In a recent post, I wrote about support vector machines, the representative master algorithm of the fifth tribe of machine learning practitioners described by Pedro Domingos in his book The Master Algorithm. Here we look at algorithms favored by the first tribe, the symbolists, who see learning as the process of inverse deduction. Pedro writes:
Another limitation of inverse deduction is that it's very computationally intensive, which makes it hard to scale to massive data sets. For these, the symbolist algorithm of choice is decision tree induction. Decision trees can be viewed as an answer to the question of what to do if rules of more than one concept match an instance. (p. 85)
The de facto standard for decision trees, or “recursive partitioning” trees as they are known in the literature, is the CART algorithm of Breiman et al. (1984), implemented in R's rpart package. Stripped down to its essential structure, CART is a two-stage algorithm. In the first stage, the algorithm conducts an exhaustive search over every variable and every candidate split point to find the split that maximizes an impurity-reduction criterion (such as the Gini index or information gain), that is, the split whose resulting cells are as pure as possible with respect to the classes of the response variable. In the second stage, a constant model is fit to each cell of the resulting partition. The algorithm proceeds recursively in a “greedy” fashion: it makes each split without looking back to reconsider earlier splits before making the next one. Although hugely successful in practice, the algorithm has two vexing problems: (1) overfitting and (2) selection bias, because the algorithm favors features with many possible splits [1]. Overfitting occurs because the algorithm has “no concept of statistical significance” [2]. While overfitting is usually handled with cross-validation and pruning, there doesn’t seem to be an easy way to deal with selection bias in the CART / rpart framework.
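To make the first stage concrete, here is a minimal base-R sketch of an exhaustive split search for a 0/1 response. It is an illustration, not rpart's implementation: the helper names gini() and best_split() are hypothetical, only numeric predictors are handled, and candidate cut points are assumed to be the midpoints between consecutive distinct values. Note that a variable with more distinct values offers more candidate cut points, and therefore more chances to look good by accident, which is the root of the selection bias described above.

```r
# Gini impurity of a 0/1 response vector
gini <- function(y) {
  p <- mean(y)
  2 * p * (1 - p)
}

# Exhaustive search for the single best split of a binary response y over
# every numeric column of X, scored by the decrease in Gini impurity
best_split <- function(X, y) {
  best <- list(var = NA_character_, cut = NA_real_, gain = -Inf)
  parent <- gini(y)
  n <- length(y)
  for (v in names(X)) {
    x <- X[[v]]
    vals <- sort(unique(x))
    if (length(vals) < 2) next
    # candidate cut points: midpoints between consecutive distinct values
    cuts <- (head(vals, -1) + tail(vals, -1)) / 2
    for (cp in cuts) {
      left  <- y[x <= cp]
      right <- y[x > cp]
      # size-weighted impurity of the two child cells
      child <- (length(left) * gini(left) + length(right) * gini(right)) / n
      gain  <- parent - child
      if (gain > best$gain) best <- list(var = v, cut = cp, gain = gain)
    }
  }
  best
}

# Example: variable a separates the classes perfectly, b does not
X <- data.frame(a = c(1, 2, 3, 4, 5, 6), b = c(5, 3, 6, 1, 4, 2))
y <- c(0, 0, 0, 1, 1, 1)
s <- best_split(X, y)
s$var  # "a"
s$cut  # 3.5
```

CART would apply this search again inside each child cell, greedily, until a stopping rule is met. For a 0/1 response the Gini impurity 2p(1 - p) is proportional to the within-cell variance p(1 - p), so this search ranks splits the same way as the sum-of-squares criterion rpart uses for a numeric response, which is what the script below fits after recoding Class to 1/0.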
To address these issues, Hothorn, Hornik and Zeileis introduced the party package into R about ten years ago; it provides an implementation of conditional inference trees (Unbiased Recursive Partitioning: A Conditional Inference Framework [1]). Party’s ctree() function separates the selection of the splitting variable from the splitting process itself, performing them as two distinct steps, and explicitly addresses selection bias by using statistical tests and a stopping rule in the first step. Very roughly, the algorithm proceeds as follows:
1. Each node of the tree is represented by a set of case weights. For the observations in the node, the algorithm tests the global null hypothesis that the response Y is independent of every covariate X_j (by default with a Bonferroni adjustment for testing several covariates). If this hypothesis cannot be rejected, the algorithm stops. Otherwise, the covariate with the strongest association with Y (the smallest adjusted p-value) is selected for splitting.
2. The algorithm finds the best binary split of the selected covariate and updates the case weights to describe the two resulting daughter nodes.
3. Steps 1 and 2 are repeated recursively with the new case weights.
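Step 1 can be sketched in a few lines of base R. This is a simplified stand-in, not party's implementation: ctree() uses linear test statistics within a permutation-test framework, whereas the hypothetical helpers perm_pvalue() and select_variable() below use a Monte Carlo permutation test of the absolute Pearson correlation. The structure is the same, though: test each covariate, adjust for multiplicity, stop if nothing is significant, and otherwise select the covariate with the smallest adjusted p-value.

```r
# Monte Carlo permutation p-value for H0: y independent of x,
# using |cor(x, y)| as the test statistic
perm_pvalue <- function(x, y, B = 999) {
  obs  <- abs(cor(x, y))
  perm <- replicate(B, abs(cor(x, sample(y))))
  # count permuted statistics at least as extreme as the observed one
  # (a small tolerance guards against floating-point ties)
  (1 + sum(perm >= obs - 1e-12)) / (B + 1)
}

# Step 1: test every covariate, Bonferroni-adjust, then stop or select
select_variable <- function(X, y, alpha = 0.05, B = 999) {
  p     <- vapply(X, perm_pvalue, numeric(1), y = y, B = B)
  p_adj <- p.adjust(p, method = "bonferroni")
  if (min(p_adj) > alpha) {
    return(list(stop = TRUE, var = NA_character_, p_adj = p_adj))
  }
  list(stop = FALSE, var = names(which.min(p_adj)), p_adj = p_adj)
}

# Example: x1 is associated with y, x2 is not
set.seed(1)
X <- data.frame(x1 = 1:20, x2 = rep(c(1, 2), 10))
y <- rep(c(0, 1), each = 10)
res <- select_variable(X, y)
res$var  # "x1"
```

With alpha = 0.05 the stopping rule mirrors party's default mincriterion = 0.95, since ctree() implements a split only when 1 minus the adjusted p-value exceeds mincriterion.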
The details, along with enough theory to use the ctree algorithm with some confidence, are presented in the accessible vignette “party: A Laboratory for Recursive Partitioning”. The following example contrasts the ctree() and rpart() algorithms.
We begin by dividing the segmentationData data set that comes with the caret package into training (70%) and test sets and fitting a ctree() model to the training data using the default parameters. No attempt is made to optimize the model. Next, we use the model to predict the Class variable (coded 1 for poorly segmented cells, PS, and 0 otherwise) on the test set and compute the area under the ROC curve (AUC): 0.8326.
```r
# Script to compare ctree with rpart
library(party)   # ctree()
library(rpart)   # rpart()
library(caret)   # segmentationData, createDataPartition()
library(pROC)    # roc()

### Get the Data
# Load the data and construct indices to divide it into training and test data sets.
data(segmentationData)                          # Load the segmentation data set
data <- segmentationData[, 3:61]                # Keep Class and the 58 predictors
data$Class <- ifelse(data$Class == "PS", 1, 0)  # Code poorly segmented (PS) cells as 1

set.seed(23)  # Set the seed before partitioning so the split is reproducible
trainIndex <- createDataPartition(data$Class, p = .7, list = FALSE)
trainData <- data[trainIndex, ]
testData  <- data[-trainIndex, ]

#------------------------
# Fit Conditional Tree Model
ctree.fit <- ctree(Class ~ ., data = trainData)
ctree.fit
plot(ctree.fit, main = "ctree Model")

# Make predictions using the test data set
ctree.pred <- predict(ctree.fit, testData)

# Draw the ROC curve
ctree.ROC <- roc(predictor = as.numeric(ctree.pred), response = testData$Class)
ctree.ROC$auc  # Area under the curve in the original run: 0.8326 (depends on the partition)
plot(ctree.ROC, main = "ctree ROC")
```
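pROC does the AUC computation for us, but the quantity itself is simple. As a rough cross-check, the hypothetical helper auc_rank() below computes the AUC directly as the probability that a randomly chosen class-1 case receives a higher predicted value than a randomly chosen class-0 case, counting ties as one half. Ties matter here because a tree predicts the same value for every observation in a terminal node. The sketch assumes that higher scores indicate class 1; pROC's roc(), by default, chooses the direction of comparison automatically.

```r
# AUC as the Mann-Whitney probability P(score_pos > score_neg), ties count 1/2
auc_rank <- function(score, label) {
  pos <- score[label == 1]
  neg <- score[label == 0]
  # compare every positive score with every negative score
  wins <- outer(pos, neg, ">") + 0.5 * outer(pos, neg, "==")
  mean(wins)
}

auc_rank(c(0.9, 0.8, 0.3, 0.1), c(1, 1, 0, 0))  # 1: perfect separation
auc_rank(c(0.1, 0.2, 0.3, 0.4), c(1, 0, 1, 0))  # 0.25
```

Applied to the script's objects, auc_rank(as.numeric(ctree.pred), testData$Class) should agree with ctree.ROC$auc as long as pROC settles on the direction in which class 1 has the higher scores.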
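Here is the text description of the resulting tree. In this printout, statistic is the value of the test statistic used to select the split variable, criterion is 1 minus the corresponding (multiplicity-adjusted) p-value, and weights is the number of training observations in each terminal node; the terminal-node weights sum to the 1414 training cases.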
```
1) FiberWidthCh1 <= 9.887543; criterion = 1, statistic = 383.388
  2) TotalIntenCh2 <= 42511; criterion = 1, statistic = 115.137
    3) TotalIntenCh1 <= 39428; criterion = 1, statistic = 20.295
      4)* weights = 504
    3) TotalIntenCh1 > 39428
      5)* weights = 9
  2) TotalIntenCh2 > 42511
    6) AvgIntenCh1 <= 199.2768; criterion = 1, statistic = 28.037
      7) IntenCoocASMCh3 <= 0.5188792; criterion = 0.99, statistic = 14.022
        8)* weights = 188
      7) IntenCoocASMCh3 > 0.5188792
        9)* weights = 7
    6) AvgIntenCh1 > 199.2768
      10)* weights = 36
1) FiberWidthCh1 > 9.887543
  11) ShapeP2ACh1 <= 1.227156; criterion = 1, statistic = 48.226
    12)* weights = 169
  11) ShapeP2ACh1 > 1.227156
    13) IntenCoocContrastCh3 <= 12.32349; criterion = 1, statistic = 22.349
      14) SkewIntenCh4 <= 1.148388; criterion = 0.998, statistic = 16.78
        15)* weights = 317
      14) SkewIntenCh4 > 1.148388
        16)* weights = 109
    13) IntenCoocContrastCh3 > 12.32349
      17) AvgIntenCh2 <= 244.9512; criterion = 0.999, statistic = 19.382
        18)* weights = 53
      17) AvgIntenCh2 > 244.9512
        19)* weights = 22
```
Next, we fit an rpart() model to the training data using the default parameter settings and compute the AUC on the test data: 0.8536.
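```r
# Fit CART Model with the default complexity parameter (cp = 0.01)
rpart.fit <- rpart(Class ~ ., data = trainData)
rpart.fit
# as.party() lives in the partykit package; calling it with :: avoids
# attaching partykit, whose ctree() would mask party's ctree()
plot(partykit::as.party(rpart.fit), main = "rpart Model")

# Make predictions using the test data set
rpart.pred <- predict(rpart.fit, testData)

# Draw the ROC curve
rpart.ROC <- roc(predictor = as.numeric(rpart.pred), response = testData$Class)
rpart.ROC$auc  # Area under the curve in the original run: 0.8536 (depends on the partition)
plot(rpart.ROC, main = "rpart ROC")
```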
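The resulting tree, grown with the default cp = 0.01, does somewhat better than the ctree() model (AUC 0.8536 versus 0.8326), but at the expense of a larger tree: 14 terminal nodes and seven levels of splits, compared with 10 terminal nodes and four levels for ctree(). Each line of the printout gives the node number, the split, the number of observations, the deviance and the fitted value (the mean of Class, i.e. the proportion of PS cells); * marks a terminal node.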
```
1) root 1414 325.211500 0.64144270
  2) TotalIntenCh2>=42606.5 792 191.635100 0.41035350
    4) FiberWidthCh1>=11.19756 447 85.897090 0.25950780
      8) ShapeP2ACh1< 1.225676 155 13.548390 0.09677419 *
      9) ShapeP2ACh1>=1.225676 292 66.065070 0.34589040
        18) SkewIntenCh4< 1.41772 254 53.259840 0.29921260
          36) TotalIntenCh4< 127285.5 214 40.373830 0.25233640
            72) EqEllipseOblateVolCh1>=383.1453 142 19.943660 0.16901410 *
            73) EqEllipseOblateVolCh1< 383.1453 72 17.500000 0.41666670
              146) AvgIntenCh1>=110.2253 40 6.400000 0.20000000 *
              147) AvgIntenCh1< 110.2253 32 6.875000 0.68750000 *
          37) TotalIntenCh4>=127285.5 40 9.900000 0.55000000 *
        19) SkewIntenCh4>=1.41772 38 8.552632 0.65789470 *
    5) FiberWidthCh1< 11.19756 345 82.388410 0.60579710
      10) KurtIntenCh1< -0.3447192 121 28.000000 0.36363640
        20) TotalIntenCh1>=13594 98 19.561220 0.27551020 *
        21) TotalIntenCh1< 13594 23 4.434783 0.73913040 *
      11) KurtIntenCh1>=-0.3447192 224 43.459820 0.73660710
        22) AvgIntenCh1>=454.3329 7 0.000000 0.00000000 *
        23) AvgIntenCh1< 454.3329 217 39.539170 0.76036870
          46) VarIntenCh4< 130.9745 141 31.333330 0.66666670
            92) NeighborAvgDistCh1>=256.5239 30 6.300000 0.30000000 *
            93) NeighborAvgDistCh1< 256.5239 111 19.909910 0.76576580 *
          47) VarIntenCh4>=130.9745 76 4.671053 0.93421050 *
  3) TotalIntenCh2< 42606.5 622 37.427650 0.93569130
    6) ShapeP2ACh1< 1.236261 11 2.545455 0.36363640 *
    7) ShapeP2ACh1>=1.236261 611 31.217680 0.94599020 *
```
Note, however, that if the complexity parameter for rpart(), cp, is set to zero, rpart() builds a massive tree, a portion of which is shown below, and overfits the data, yielding an AUC of 0.806.
```
1) root 1414 325.2115000 0.64144270
  2) TotalIntenCh2>=42606.5 792 191.6351000 0.41035350
    4) FiberWidthCh1>=11.19756 447 85.8970900 0.25950780
      8) ShapeP2ACh1< 1.225676 155 13.5483900 0.09677419
        16) EntropyIntenCh1>=6.672119 133 7.5187970 0.06015038
          32) AngleCh1< 108.6438 82 0.0000000 0.00000000 *
          33) AngleCh1>=108.6438 51 6.7450980 0.15686270
            66) EqEllipseLWRCh1>=1.184478 26 0.9615385 0.03846154
              132) DiffIntenDensityCh3>=26.47004 19 0.0000000 0.00000000 *
              133) DiffIntenDensityCh3< 26.47004 7 0.8571429 0.14285710 *
            67) EqEllipseLWRCh1< 1.184478 25 5.0400000 0.28000000
              134) IntenCoocContrastCh3>=9.637027 9 0.0000000 0.00000000 *
              135) IntenCoocContrastCh3< 9.637027 16 3.9375000 0.43750000 *
        17) EntropyIntenCh1< 6.672119 22 4.7727270 0.31818180
          34) ShapeBFRCh1>=0.6778205 13 0.0000000 0.00000000 *
          35) ShapeBFRCh1< 0.6778205 9 1.5555560 0.77777780 *
      9) ShapeP2ACh1>=1.225676 292 66.0650700 0.34589040
        18) SkewIntenCh4< 1.41772 254 53.2598400 0.29921260
          36) TotalIntenCh4< 127285.5 214 40.3738300 0.25233640
            72) EqEllipseOblateVolCh1>=383.1453 142 19.9436600 0.16901410
              144) IntenCoocEntropyCh3< 7.059374 133 16.2857100 0.14285710
                288) NeighborMinDistCh1>=21.91001 116 11.5431000 0.11206900
                  576) NeighborAvgDistCh1>=170.2248 108 8.2500000 0.08333333
                    1152) FiberAlign2Ch4< 1.481728 68 0.9852941 0.01470588
                      2304) XCentroid>=100.5 61 0.0000000 0.00000000 *
                      2305) XCentroid< 100.5 7 0.8571429 0.14285710 *
                    1153) FiberAlign2Ch4>=1.481728 40 6.4000000 0.20000000
                      2306) SkewIntenCh1< 0.9963465 27 1.8518520 0.07407407
...
```
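The node numbers in these rpart printouts encode the tree structure: the children of node k are numbered 2k and 2k + 1, so a node's depth and its path from the root can be recovered from its number alone. The small helpers below (hypothetical names node_depth() and node_path()) make this concrete; node 2306, the last one shown above, sits eleven splits below the root, compared with at most seven in the default-cp tree.

```r
# Depth of an rpart node: the root (node 1) has depth 0
node_depth <- function(node) floor(log2(node))

# Path from the root to a node, using parent(k) = k %/% 2
node_path <- function(node) {
  path <- node
  while (node > 1) {
    node <- node %/% 2
    path <- c(node, path)
  }
  path
}

node_depth(2306)  # 11
node_path(2306)   # 1 2 4 9 18 36 72 144 288 576 1153 2306
```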
In practice, rpart()'s complexity parameter (default value cp = 0.01) is effective in controlling tree growth and overfitting. It does, however, have an “ad hoc” feel to it. In contrast, the ctree() algorithm builds tests for statistical significance into the process of growing the tree. It automatically curtails excessive growth, inherently addresses both overfitting and selection bias, and offers the promise of achieving good models with less computation, since it needs no separate pruning step.
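The effect of cp, and the cross-validation-and-pruning remedy for overfitting, can be seen on a small simulated data set. This is a sketch under stated assumptions: the data are synthetic and made up for illustration (one informative predictor and four noise predictors), the helper n_leaves() is hypothetical, and the pruning rule used is simply the cp value with the smallest cross-validated error (xerror) in rpart's cptable, one common choice among several.

```r
library(rpart)  # recommended package shipped with standard R installations

# Synthetic data: only X1 carries signal, X2..X5 are noise
set.seed(42)
n  <- 500
df <- data.frame(matrix(rnorm(n * 5), ncol = 5))  # columns X1..X5
df$y <- as.numeric(df$X1 + rnorm(n) > 0)

# Count the terminal nodes of an rpart fit
n_leaves <- function(fit) sum(fit$frame$var == "<leaf>")

fit_default <- rpart(y ~ ., data = df)          # default cp = 0.01
fit_full    <- rpart(y ~ ., data = df, cp = 0)  # no complexity penalty

# Prune the full tree back to the cp with the lowest cross-validated error
best_cp    <- fit_full$cptable[which.min(fit_full$cptable[, "xerror"]), "CP"]
fit_pruned <- prune(fit_full, cp = best_cp)

c(default = n_leaves(fit_default),
  full    = n_leaves(fit_full),
  pruned  = n_leaves(fit_pruned))
```

One would expect the cp = 0 tree to be much larger than the other two; the exact counts depend on the simulated sample, so none are quoted here.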
Finally, note that rpart() and ctree() construct different trees that offer about the same performance. Some practitioners who value decision trees for their interpretability find this disconcerting. End users of machine learning models often want a story that tells them something true about their customers' behavior or buying preferences. But the possibility of multiple satisfactory answers to a complex problem is inherent in the process of inverse deduction. As Hothorn et al. comment:
Since a key reason for the popularity of tree based methods stems from their ability to represent the estimated regression relationship in an intuitive way, interpretations drawn from regression trees must be taken with a grain of salt.
1. Hothorn et al. (2006). Unbiased Recursive Partitioning: A Conditional Inference Framework. Journal of Computational and Graphical Statistics, Vol. 15, No. 3, September 2006.
2. Mingers (1987). Expert Systems-Rule Induction with Statistical Data.
as.party function not found, and I can find no information about it, except as a C function. How do we get the rpart.fit plot to include labels? I tried using the text.part function, but the output is too dense for readability.
Posted by: Thomas Keller | October 23, 2015 at 11:46
The as.party() function is in the partykit package. Sorry, I left library(partykit) out of the code.
Posted by: Joseph Rickert | October 23, 2015 at 14:29