Predictive modeling with Machine Learning in R — Part 3 (Classification — Basic)
“Data is the new currency, and it’s the medium of exchange between consumers and marketers” — Lisa Utzschneider
This is the third post in the series Predictive Modeling with Machine Learning (ML) in R. For the earlier posts, please refer to the links below.
Predictive modeling with Machine Learning in R — Part 1(Introduction)
Predictive modeling with Machine Learning in R — Part 2 (Evaluation Metrics for Classification)
In this post, we shall learn about the following things.
- What is classification?
- Hands-on experience using a dataset
1. What is classification?
Classification is a form of pattern recognition in which the algorithm uses the training data to identify patterns that lead to an outcome. Once it has learned these patterns with sufficient confidence (or accuracy), it can predict the outcomes for future, unseen data.
One of the most common examples of classification that we see in our daily lives is the classification of an email as “spam” or “inbox”. In the picture at the top, the classification model classifies a set of items into vegetables or groceries.
2. Hands-on experience using a dataset
Let’s use the standard “iris” dataset available in R to perform classification. For those who are not aware of this dataset, let me give a brief overview. This famous dataset provides variables like sepal length/width and petal length/width, measured in cm, for 50 observations each across 3 different species of flowers (setosa, versicolor and virginica). So, the inputs are the measurements and the output (the one we need to predict) is the species.
In this example, we will be using a package called caret, which provides a unified interface to most common ML algorithms such as linear models, random forests (RF), and support vector machines (SVM). This package also has handy tools to split the dataset and to evaluate model performance.
Caveat: In this post, let’s do no-frills predictive modeling. This means we won’t worry too much about data normalization, feature extraction, or parameter tuning. We shall dive deep into those topics in later posts. Here, let’s get started with learning how to run ML models, with default parameters, on a dataset and evaluate them.
Here’s the entire code — copy-pasting it into your R editor should work.
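Since the embedded gist doesn’t carry over well into text, here is a consolidated sketch stitched together from the section-by-section snippets below. The library() calls are assumptions on my part, and some methods or plots may prompt you to install extra packages (for example MASS, rpart, kernlab, randomForest, or e1071).

# 0. Loading data and libraries
library(caret)        # model training, data splitting, evaluation
library(tidyverse)    # data manipulation (select, %>%)
data(iris)

# 1. Understanding the data
str(iris)
table(iris$Species)
library(AppliedPredictiveModeling)   # provides transparentTheme()
transparentTheme(trans = .4)
featurePlot(iris[,1:4], iris$Species, plot = "ellipse", auto.key = list(columns = 3))

# 2. Splitting the data
index = createDataPartition(iris$Species, p = 0.7, list = FALSE)
train_data = iris[index,]
test_data = iris[-index,]
X_train = train_data %>% select(-Species)
y_train = train_data %>% select(Species)
X_test = test_data %>% select(-Species)
y_test = test_data %>% select(Species)

# 3. Training the model(s)
metric <- "Accuracy"
set.seed(7); fit.lda  <- train(X_train, y_train$Species, method="lda",       metric=metric)
set.seed(7); fit.cart <- train(X_train, y_train$Species, method="rpart",     metric=metric)
set.seed(7); fit.knn  <- train(X_train, y_train$Species, method="knn",       metric=metric)
set.seed(7); fit.svm  <- train(X_train, y_train$Species, method="svmRadial", metric=metric)
set.seed(7); fit.rf   <- train(X_train, y_train$Species, method="rf",        metric=metric)

# 4. Assessing the performance of various models
results <- resamples(list(lda=fit.lda, cart=fit.cart, knn=fit.knn, svm=fit.svm, rf=fit.rf))
summary(results)

# 5 & 6. Making predictions with the best model and evaluating them
predictions <- predict(fit.lda, X_test)
confusionMatrix(predictions, as.factor(y_test$Species))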
Let’s go through one section at a time and see what’s happening.
Section 0: Loading data and libraries
caret and tidyverse are the two libraries needed to run this script. If you do not have these packages, install them using the install.packages() command.
data(iris) loads the iris data into the R environment.
For a comprehensive list of built-in datasets in R, execute the command data() in your console. Check out other packages like mlbench and AppliedPredictiveModeling as well for more datasets (mlbench, for instance, includes the Sonar dataset).
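The post doesn’t show a separate code block for Section 0, so here is a minimal sketch of what it would look like, assuming only caret and tidyverse are needed at this point:

# 0. Loading data and libraries
library(caret)       # model training, data splitting, evaluation
library(tidyverse)   # data manipulation (select, %>%)
data(iris)           # load the built-in iris dataset
data()               # optional: list all built-in datasets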
Section 1: Understanding the data
Understanding the data, together with domain expertise, is the key to building a good predictive model with ML. So, let’s look at a summary of our dataset and see how many output classes we have. The summary reveals that our dataset has 150 rows and 5 columns. Only the Species column is a factor, while the other columns are numeric. Using the table command on the iris dataset reveals that each of the 3 output classes has an equal number of observations. In real life this is often not the case, and techniques like oversampling and undersampling exist to address the imbalance. We shall deal with this in later posts.
# 1. Understanding the data
# a. Let's look at a summary of the data
str(iris)
'data.frame': 150 obs. of 5 variables:
 $ Sepal.Length: num 5.1 4.9 4.7 4.6 5 5.4 4.6 5 4.4 4.9 ...
 $ Sepal.Width : num 3.5 3 3.2 3.1 3.6 3.9 3.4 3.4 2.9 3.1 ...
 $ Petal.Length: num 1.4 1.4 1.3 1.5 1.4 1.7 1.4 1.5 1.4 1.5 ...
 $ Petal.Width : num 0.2 0.2 0.2 0.2 0.2 0.4 0.3 0.2 0.2 0.1 ...
 $ Species     : Factor w/ 3 levels "setosa","versicolor",..: 1 1 1 1 1 1 1 1 1 1 ...

# b. How many samples do we have in each class?
table(iris$Species)
    setosa versicolor  virginica
        50         50         50
Let’s take a look visually at how the features (all columns except the output column, Species) relate among themselves as well as with the output. For this, we shall make use of the handy featurePlot command from the caret package, along with the transparentTheme helper from the AppliedPredictiveModeling package. As you can see in the code snippet below, this command plots a scatter plot matrix of the features (iris[,1:4]) against the output (iris$Species). The plot argument tells the command to draw an ellipse around the data points of each class. Other options you can use instead of ellipse are box, scatter, density, pairs, etc. The output of this command is shown in the figure below the code snippet.
# c. How are the inputs related among themselves & how are they related to the output?
# transparentTheme() comes from the AppliedPredictiveModeling package
transparentTheme(trans = .4)
featurePlot(iris[,1:4], iris$Species, plot = "ellipse", auto.key = list(columns = 3))
Section 2: Splitting the data
In this step, we first split the data into training and testing sets using the createDataPartition command from the caret package. Let’s look at this command in detail.
index = createDataPartition(iris$Species, p = 0.7, list = FALSE)
train_data = iris[index,]
test_data = iris[-index,]
This command accepts two critical arguments. The first is the column on which we want to stratify the split; in our case, that’s the Species column. The second argument, p, is the proportion that goes into training — 0.7 for a 70–30 split, 0.85 for an 85–15 split, and so on. The output of this command, index, contains the row numbers of a random 70% of the iris dataset, sampled within each class so that the class proportions are preserved. This index variable is then used to create the train and test datasets, and you can verify the stratification with the quick check below.
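As a quick sanity check (not part of the original script), you can confirm that the split is stratified by class. Since each class has 50 observations, a 70% split should keep roughly 35 of each:

# index is a one-column matrix of row numbers; using it to subset Species
# shows how many rows of each class landed in the training split
table(iris$Species[index])
# expected: 35 setosa, 35 versicolor, 35 virginica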
Next, we split train_data and test_data into inputs (X) and outputs (y). As seen in the code snippet below, we use the select command from the dplyr package to keep only the input columns (everything except the outcome, Species) in X and only the outcome (Species) in y.
X_train = train_data %>% select(-Species)
y_train = train_data %>% select(Species)
X_test = test_data %>% select(-Species)
y_test = test_data %>% select(Species)
Section 3: Training the model(s)
From the scatter plot matrix above, we get an indication that the classes are (at least partially) linearly separable in certain dimensions. This means that, in general, we should get good results with almost all algorithms. For illustration and comparison purposes, let us build 5 models, which is a breeze with the caret package. I’ve categorized these models into linear and non-linear models. Let’s look at the code and dive deep. If you take a closer look at the code snippet, the syntax for all the models remains the same except for the change in the method argument. I’ve used 5 different methods here: linear discriminant analysis (lda), decision tree (rpart), k-nearest neighbours (knn), support vector machine with a radial kernel (svmRadial), and random forest (rf). For a comprehensive list of methods available in caret, please refer here. In the code snippet, you will also notice a set.seed(7) command before each call to train. This resets the random number seed each time so that every model sees the same resampling splits, which keeps the results reproducible and comparable.
# 3. Training the model(s) on the training data
metric <- "Accuracy"

# a) linear algorithm
set.seed(7)
fit.lda <- train(X_train, y_train$Species, method="lda", metric=metric)

# b) nonlinear algorithms
# CART
set.seed(7)
fit.cart <- train(X_train, y_train$Species, method="rpart", metric=metric)
# kNN
set.seed(7)
fit.knn <- train(X_train, y_train$Species, method="knn", metric=metric)

# c) advanced algorithms
# SVM
set.seed(7)
fit.svm <- train(X_train, y_train$Species, method="svmRadial", metric=metric)
# Random Forest
set.seed(7)
fit.rf <- train(X_train, y_train$Species, method="rf", metric=metric)
Section 4: Assessing the performance of various models
Now, we have built 5 models for our dataset. Which model should we choose? Since this is a simple version of the classification problem, let’s stick to Accuracy as our metric to choose the best model. For a detailed account of other metrics, do refer to my earlier post on this topic. Based on the mean accuracy across the resamples, the LDA model comes out on top, so we shall stick with it for making predictions on the test data.
# 4. Assessing the performance of various models
results <- resamples(list(lda=fit.lda, cart=fit.cart, knn=fit.knn, svm=fit.svm, rf=fit.rf))
summary(results)

Accuracy
          Min.   1st Qu.    Median      Mean   3rd Qu.     Max. NA's
lda  0.9024390 0.9268293 0.9473684 0.9519305 0.9729730 1.000000    0
cart 0.8292683 0.9090909 0.9189189 0.9234116 0.9459459 0.975000    0
knn  0.8823529 0.9268293 0.9473684 0.9461249 0.9722222 1.000000    0
svm  0.8529412 0.9069767 0.9189189 0.9228863 0.9250000 0.972973    0
rf   0.8750000 0.9189189 0.9230769 0.9290644 0.9444444 1.000000    0
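If you prefer a picture over the summary table, caret also provides lattice-based plots for resamples objects. A minimal sketch (the plot call is my addition, not part of the original post):

# Dot plot of resampled accuracy for the five models
dotplot(results, metric = "Accuracy")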
Section 5 & 6: Making predictions using the best model & Evaluating the predictions
We run the LDA model (fit.lda) on the test dataset to obtain the predictions and then evaluate them. confusionMatrix is a handy command that quickly displays the results in a classes x classes matrix: the true values are recorded down each column (Reference) and the predictions across each row (Prediction). Ideally, we want this to be a diagonal matrix, meaning all predictions are correct. Since the LDA model achieved a diagonal matrix here, the accuracy is 100%.
# 5. Making predictions out of the test data using the best model - Linear
predictions <- predict(fit.lda, X_test)

# 6. Evaluating the prediction of Linear model
confusionMatrix(predictions, as.factor(y_test$Species))

Confusion Matrix and Statistics

            Reference
Prediction   setosa versicolor virginica
  setosa         15          0         0
  versicolor      0         15         0
  virginica       0          0        15

Overall Statistics
               Accuracy : 1
                 95% CI : (0.9213, 1)
    No Information Rate : 0.3333
    P-Value [Acc > NIR] : < 2.2e-16
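As a quick sanity check (not part of the original script), the accuracy reported by confusionMatrix is simply the proportion of predictions that match the true labels, which you can compute directly:

# Fraction of test rows where the predicted species equals the true species
mean(predictions == y_test$Species)
# [1] 1   (given the all-diagonal confusion matrix above)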
Conclusions
Congrats! Those were our first ML models. I hope you had fun building them. Without going into the details of each model or parameter tuning, we managed to walk through the typical steps involved in building a predictive model. The framework would remain the same for any algorithm or dataset, with additional steps of feature extraction or feature engineering for some datasets.
In the next post, let’s pick one algorithm (like RF) and dive deep into data normalization, cross-validation, parameter tuning, and feature extraction on some other dataset. Till then, try out the approach we discussed here on the iris dataset with only 2 classes (excluding virginica) — a small starter snippet is sketched below.
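As a starting point for that exercise, here is one way (a sketch using dplyr, with a hypothetical variable name iris_2class) to drop the virginica class before re-running the same pipeline:

# Keep only setosa and versicolor, and drop the now-unused factor level
iris_2class <- iris %>%
  filter(Species != "virginica") %>%
  droplevels()
table(iris_2class$Species)   # should show 50 setosa and 50 versicolor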