by Bob Horton
Microsoft Senior Data Scientist
Learning curves are an elaboration of the idea of validating a model on a test set, and have been widely popularized by Andrew Ng’s Machine Learning course on Coursera. Here I present a simple simulation that illustrates this idea.
Imagine you use a sample of your data to train a model, then use the model to predict the outcomes on data where you know what the real outcome is. Since you know the “real” answer, you can calculate the overall error in your predictions. The error on the same data set used to train the model is called the training error, and the error on an independent sample is called the validation error.
A model will commonly perform better (that is, have lower error) on the data it was trained on than on an independent sample. The difference between the training error and the validation error reflects overfitting of the model. Overfitting is like memorizing the answers for a test instead of learning the principles (to borrow a metaphor from the Wikipedia article). Memorizing works fine if the test is exactly like the study guide, but it doesn’t work very well if the test questions are different; that is, it doesn’t generalize. In fact, the more a model is overfitted, the higher its validation error is likely to be. This is because the spurious correlations the overfitted model memorized from the training set most likely don’t apply in the validation set.
Overfitting is usually more extreme with small training sets. In large training sets the random noise tends to average out, so that the underlying patterns are more clear. But in small training sets, there is less opportunity for averaging out the noise, and accidental correlations consequently have more influence on the model. Learning curves let us visualize this relationship between training set size and the degree of overfitting.
We start with a function to generate simulated data:
sim_data <- function(N, noise_level=1){
  X1 <- sample(LETTERS[1:10], N, replace=TRUE)
  X2 <- sample(LETTERS[1:10], N, replace=TRUE)
  X3 <- sample(LETTERS[1:10], N, replace=TRUE)
  y <- 100 + ifelse(X1 == X2, 10, 0) + rnorm(N, sd=noise_level)
  data.frame(X1, X2, X3, y)
}
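For a quick look at the kind of data this produces, you can generate and inspect a handful of rows (an illustrative check; the exact values depend on your random number state):
# Generate a few rows with the default noise level and inspect them
head(sim_data(10))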
The input columns X1, X2, and X3 are categorical variables, each with 10 possible values represented by the capital letters A through J. The outcome is cleverly named y; it has a base level of 100, but if the values of X1 and X2 are equal, this is increased by 10. On top of this we add some normally distributed noise. Any other pattern that might appear in the data is accidental.
Now we can use this function to generate a simulated data set for experiments.
set.seed(123)
data <- sim_data(25000, noise_level=10)
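A quick check confirms that the simulated pattern is there: the mean of y should be about 10 units higher in the rows where X1 equals X2.
# Compare the mean outcome for rows where X1 == X2 against the rest;
# the difference should be close to the 10-unit bump built into sim_data
tapply(data$y, data$X1 == data$X2, mean)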
There are many possible error functions, but I prefer the root mean squared error:
rmse <- function(actual, predicted) sqrt( mean( (actual - predicted)^2 ))
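As a simple point of reference, predicting the overall mean of y for every row gives a baseline RMSE that a model capturing the built-in pattern should improve on:
# Baseline: predict the overall mean of y for every row.
# This RMSE is essentially sd(y); a model that captures the
# X1 == X2 bump should come in somewhat lower.
rmse(data$y, mean(data$y))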
To generate a learning curve, we fit models at a series of different training set sizes, and calculate the training error and validation error for each model. Then we plot these errors against the training set size. The parameters are a model formula, the data frame of simulated data, the validation set size (vss), the number of different training set sizes to plot (num_tss), and the smallest training set size to start with (min_tss). The largest training set uses all the rows of the dataset that are not reserved for validation.
run_learning_curve <- function(model_formula, data, vss=5000, num_tss=30, min_tss=1000){
  library(data.table)
  max_tss <- nrow(data) - vss
  tss_vector <- seq(min_tss, max_tss, length.out=num_tss)
  data.table::rbindlist(lapply(tss_vector, function(tss){
    # Hold out a validation set, then sample a training set from the remaining rows
    vs_idx <- sample(1:nrow(data), vss)
    vs <- data[vs_idx,]
    ts_eligible <- setdiff(1:nrow(data), vs_idx)
    ts <- data[sample(ts_eligible, tss),]
    # Fit the model on the training set and score it on both sets
    fit <- lm(model_formula, ts)
    training_error <- rmse(ts$y, predict(fit, ts))
    validation_error <- rmse(vs$y, predict(fit, vs))
    data.frame(tss=tss,
               error_type=factor(c("training", "validation"),
                                 levels=c("validation", "training")),
               error=c(training_error, validation_error))
  }))
}
We’ll use a formula that considers all combinations of the input columns. Since these are categorical inputs, they will be represented by dummy variables in the model, with each combination of variable values getting its own coefficient.
learning_curve <- run_learning_curve(y ~ X1*X2*X3, data)
With this example, you get a series of warnings:
## Warning in predict.lm(fit, vs): prediction from a rank-deficient fit may be
## misleading
This is R trying to tell you that you don’t have enough rows to reliably fit all those coefficients. In this simulation, training set sizes above about 7500 don’t trigger the warning, though as we’ll see the curve still shows some evidence of overfitting.
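To see why, you can compare the number of coefficients the interaction formula asks for with the rank of a fit on a small training set (a quick check using the same data and formula; this is not part of the learning curve itself):
# The full interaction expands to 10 x 10 x 10 = 1000 model matrix columns
ncol(model.matrix(y ~ X1*X2*X3, data))
# A fit on only 1000 rows cannot estimate all of them independently,
# so the rank of the fit falls short of the number of coefficients
small_fit <- lm(y ~ X1*X2*X3, data[sample(nrow(data), 1000), ])
length(coef(small_fit))  # coefficients requested by the formula
small_fit$rank           # coefficients the fit could actually estimate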
library(ggplot2)
ggplot(learning_curve, aes(x=tss, y=error, linetype=error_type)) +
  geom_line(size=1, col="blue") +
  xlab("training set size") +
  geom_hline(yintercept=10, linetype=3)
In this figure, the X-axis represents the training set size and the Y-axis represents error. Validation error is shown by the solid blue line in the top part of the figure, and training error by the dashed blue line in the bottom part. As the training set size grows, these curves converge toward a level representing the amount of irreducible error in the data. This plot was generated from a simulated dataset where we know exactly what the irreducible error is: it is the standard deviation of the Gaussian noise we added to the output in the simulation (10; for reasonably large samples, the root mean squared error of the noise is essentially its standard deviation). We don't expect any model to predict this noise, since it is completely random.
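You can check this numerically by looking at the rows of the learning curve table where the training set is largest; both error values should be close to the noise standard deviation of 10 (a quick check, assuming the learning_curve object computed above):
# learning_curve is a data.table (from rbindlist), so we can filter on tss;
# these are the training and validation errors at the largest training set size
learning_curve[tss == max(tss)]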
One interesting thing about this simulation is that the underlying system is very simple, yet it can take many thousands of training examples before the validation error of this model gets close to the optimum. In real life, you can easily encounter systems with many more variables, much higher cardinality, far more complex patterns, and of course lots of the unpredictable variation we call “noise”; in such cases, truly enormous numbers of samples may be needed to train your model without excessive overfitting. On the other hand, if your training and validation error curves have already converged, more data may be superfluous. Learning curves help you see whether more data is likely to improve your model.
I like this article very much, but something is confusing me. Why is the training error getting bigger and the validation error getting smaller as the training set size increases? Shouldn't it be the opposite? Is it possible that there is some reversal in the data, or am I missing a concept?
Posted by: Seth Chandler | September 29, 2015 at 21:17
Seth, the training error is getting bigger because the model is now fitting to the underlying pattern, rather than to the random noise in the training set. Essentially, the model is now discriminating between the noise and the underlying pattern better.
Posted by: Patrick Rogers | September 30, 2015 at 08:36
Thanks for the helpful post - I also like learning curves, but have previously struggled to find a useful and standardised way to generate them. (I envy the function in scikit-learn.)
Let's hope that they appear in caret soon!
Posted by: Rob Penfold | October 05, 2015 at 07:42
Seth;
As Patrick said, the training error goes up with increasing training set size because the model becomes less overfitted. Since this is simulated data, we know what the error from an "optimal" model should be; this is shown by the dotted line representing the amount of random noise added to the data. Any model that seems to make predictions with error rates lower than this is fooling itself. The error rate for predictions made using an overfitted model on a validation set should be higher than optimal (on average; you sometimes get some jitter, depending on how much the validation sample happens to match the biases of the overfitted model).
Posted by: Bob Horton | October 05, 2015 at 18:27