Predictive modeling with Machine Learning in R — Part 4 (Classification — Advanced)
“Data are becoming the new raw material of business.” ~Craig Mundie
This is the fourth post in the series Predictive modeling with Machine Learning (ML) in R. For earlier posts, please refer to the links below.
Predictive modeling with Machine Learning in R — Part 1(Introduction)
Predictive modeling with Machine Learning in R — Part 2 (Evaluation Metrics for Classification)
Predictive modeling with Machine Learning in R — Part 3 (Classification — Basic)
Introduction
In the last post, we learned how to run a machine learning model using default parameters. In this post, let's go one step further by extracting features, tuning parameters, performing cross-validation, and examining feature importance, using Random Forest (RF) as our model of choice. There are quite a few packages, such as caret, ranger, and randomForest, that can run the RF model in R. Let's stick to caret, since we used it in the previous post. The results would be similar if you use other packages as well.
Dataset
In this post, we will be using the same synthetic dataset that we used in one of my previous posts (Exercise 2). It's a mock-up dataset of a few hypothetical patients and their visits to a hospital. The dataset is provided as 3 different tables (which is the typical case) obtained from a few sources.
- Patient demographics: patient id, gender, ethnicity, birth year, and total cost.
- Patient movement: patient id, visit/admission date, type of visit, and discharge date.
- Patient disease: patient id and indicators for whether the patient had any of 9 diseases.
You can download the datasets here.
Objective
Our objective with these datasets is to predict whether patients will be short-stayers (≤ 3 days) or long-stayers (> 3 days). This is just an objective I made up to illustrate how to apply the RF model in R.
Here’s the entire code that you could just plug in and play.
Let’s go through one section at a time to better understand. Section 0 is just the loading of necessary libraries and the datasets.
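For reference, here is a minimal sketch of what Section 0 might look like. The file names are assumptions; substitute the names of the files you downloaded. The date conversion assumes the dates are stored in a format as.Date can parse, which is needed for the LOS subtraction later.
# Section 0 (sketch): libraries and data
library(dplyr)      # data wrangling (mutate, group_by, summarise, joins)
library(caret)      # data splitting, model training, evaluation
library(pROC)       # ROC/AUC
library(ggplot2)    # plotting variable importance
library(reshape2)   # melting the importance table for ggplot

demo_data     = read.csv("demographics.csv")   # hypothetical file name
disease_data  = read.csv("disease.csv")        # hypothetical file name
movement_data = read.csv("movement.csv")       # hypothetical file name

# Dates must be of Date class so we can subtract them later
movement_data = movement_data %>%
  mutate(admission.date = as.Date(admission.date),
         discharge.date = as.Date(discharge.date))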
STEP 1: Data Exploring + Feature Extraction + Dataset Curation
a. Exploring: From the code snippet below, it is clear that the 3 datasets contain different numbers of unique patients. The demographic dataset is the base dataset to which we shall add more features (columns) that we think will be predictive of our outcome.
Demographics data has information on 2500 patients. Disease data has information for only 875 of those patients, so we have to assume the data for the remaining patients are missing. The movement data has 4371 rows but comprises only 2069 unique patients. This means a number of patients have more than one row of information, which we will need to aggregate to the patient level (meaning only one row per patient).
str(demo_data)
'data.frame': 2500 obs. of 5 variables:
 $ i         : int 1 2 3 4 5 6 7 8 9 10 ...
 $ total.cost: num 29.1 417.1 856.1 261.9 755.3 ...
 $ birth.year: int 1937 1972 1971 1972 1988 1973 1930 1945 1975 1964 ...
 $ gender    : chr "Female" "Male" "Male" "Male" ...
 $ ethnicity : chr "Chinese" "Chinese" "Others" "Others" ...

length(unique(demo_data$i))
[1] 2500

length(unique(disease_data$i))
[1] 875

length(unique(movement_data$i))
[1] 2069
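A quick way to verify these overlaps (the object names match the snippets above; the expected values follow from the counts just quoted):
# Patients in the demographics table with no disease records
sum(!demo_data$i %in% disease_data$i)    # expect 2500 - 875 = 1625

# Average number of movement rows per patient
nrow(movement_data) / length(unique(movement_data$i))   # 4371 / 2069, about 2.1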
b. FEATURE EXTRACTION — Adding new columns
Our movement data is missing an important column: length of stay (LOS). Let's curate this column by subtracting the admission date from the discharge date. Please note we are using the mutate function from the dplyr package to add new columns. Two more indicator columns are also added, to be used while aggregating to the patient level.
movement_data = movement_data %>%
  mutate(LOS = discharge.date - admission.date,
         ip_check = ifelse(type == "Inpatient", 1, 0),
         op_check = ifelse(type == "Outpatient", 1, 0))
c. FEATURE EXTRACTION — Aggregation and adding new columns
We shall use the group_by and summarise functions to aggregate the movement data at the patient level. During this aggregation, we shall curate two new columns, the number of inpatient visits (n_ip_visits) and the number of outpatient visits (n_op_visits), and also carry forward each patient's total LOS so that it survives the aggregation.
move_data = movement_data %>% group_by(i) %>%
  summarise(n_ip_visits = sum(ip_check),
            n_op_visits = sum(op_check),
            LOS = sum(as.numeric(LOS), na.rm = TRUE))  # difftime converted to days
d. Curating the working dataset
Next, let’s combine all three datasets into a single working dataset. For this, we will be using left_join function.
working_data = left_join(demo_data, disease_data, by = "i")
working_data = left_join(working_data, move_data, by = "i")
e. Changing LOS variable into a categorical variable
The LOS column, which is currently a numeric data type, needs to be converted into two categories: Short-stayers (≤ 3 days) and Long-stayers (> 3 days). We call this new column target. Next, let's remove the patient identifier, the target's source column, and some highly correlated variables from the working dataset.
working_data = working_data %>%
  mutate(target = ifelse(LOS > 3,
                         "Long-stayers", "Short-stayers"))
working_data$target = as.factor(working_data$target)

# Drop the target's source column, the identifier, and correlated variables
# (Age is curated from birth.year in the full script linked above)
working_data = working_data %>%
  select(-c(LOS, n_ip_visits, birth.year, Age, i))
STEP 2: Handling NAs
This is a perennial question for every data scientist/analyst: how to handle missing data? There is no one answer. It depends on the problem at hand, the quantity and quality of the data, and the type of column involved. If you check our data, all of our NAs are in the target and the disease columns. The target column is of factor data type, while the 9 disease columns are of numeric data type.
For the target column, let's simply remove all the rows that do not have a valid value. In essence, we are removing all patients who do not have an inpatient stay, which makes sense since our outcome categorizes patients as short- or long-stayers. For the disease columns, let's replace the NA with a third number, 2, indicating that these patients have missing data for these disease columns. An alternative would be to replace the NA with the most frequently occurring value of that column (the mode), or to impute with the mean of that column, etc.; a sketch of the mode-based alternative appears after the code below.
# a. Target column
working_data = working_data %>% filter(!is.na(target))

# b. Disease columns
working_data = mutate_if(working_data, is.numeric, ~replace(., is.na(.), 2))
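If, instead of the replace-with-2 line above, you would rather impute with the mode, here is a minimal sketch. mode_value is a hypothetical helper, since base R has no built-in statistical mode function.
# Hypothetical helper: most frequent non-NA value of a vector
mode_value = function(x) {
  ux = unique(x[!is.na(x)])
  ux[which.max(tabulate(match(x, ux)))]
}

# Alternative to the replace-with-2 approach (not used downstream)
working_data_mode = mutate_if(working_data, is.numeric,
                              ~replace(., is.na(.), mode_value(.)))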
STEP 3. Data split and Cross-validation
Data splitting has already been explained in detail in my previous post, so let's quickly dive into cross-validation. This step is essential to ensure our model training and predictions are robust. The idea is simple: we split our training dataset into k folds of training and validation sets. As seen in the illustration, in 5-fold cross-validation, the validation set keeps moving across the dataset. At the end of the 5 folds, the performance is averaged across them.
Performing cross-validation is pretty easy with the caret package. We use the trainControl function to specify how many folds of cross-validation we need to perform. In this case, I've specified 5-fold cross-validation.
# Splitting the data into training and testing
data_split = createDataPartition(working_data$target, p = 0.8, list = FALSE)
training_data = working_data[data_split, ]
testing_data = working_data[-data_split, ]

# Separating data and labels
labels_train = training_data$target
training_data$target = NULL
labels_test = testing_data$target
testing_data$target = NULL

# Cross-validation: 5 folds, keeping the out-of-fold predictions
# (search = "random" only applies when no explicit tuneGrid is supplied later)
train_control <- trainControl(method = "cv", number = 5,
                              savePredictions = TRUE, search = "random")
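As a side note, if you want more stable performance estimates, caret also supports repeated cross-validation; the number of repeats below is just an illustrative choice.
# Variant: 5-fold cross-validation repeated 3 times
train_control_rep <- trainControl(method = "repeatedcv", number = 5,
                                  repeats = 3, savePredictions = TRUE)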
STEP 4. Model training and predictions
Again, you can refer to my previous post for details on this step; here I'll only note what's new. The train_control object obtained from the previous step is passed as an input when training the model, so the reported training results are averaged over the 5 cross-validation folds.
# Training using Random Forest within the caret package
# Note: mtry should not exceed the number of predictor columns;
# randomForest resets an oversized value with a warning
sol = train(training_data, labels_train, trControl = train_control,
            ntree = 200, method = 'rf',
            tuneGrid = data.frame(.mtry = 30), importance = TRUE)

# Predicting using the model and the testing data
pre = predict(sol, testing_data)
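The ROC/AUC in the next step is computed from the hard class labels in pre. If you want class probabilities instead, which usually give a smoother ROC curve, a sketch like this should work for the rf method:
# Predicted probability of each class for the testing data
pre_prob = predict(sol, testing_data, type = "prob")
head(pre_prob)   # one column per class, e.g. Long-stayers / Short-stayers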
STEP 5. Performance metrics and Feature importance
We use standard metrics like accuracy, sensitivity, etc. to evaluate our model. For a detailed account of what each of these metrics means, please refer to my earlier post. Here is a summary of the metrics I obtained during my run: Accuracy (87.2%), Sensitivity (92.3%), Specificity (58.7%), and AUC (75.2%). The drop in specificity pulls the AUC down.
The reason for the low specificity is the class imbalance in our dataset: while we have 1749 short-stayers, we only have 320 long-stayers.
There are a few ways to handle such class imbalance.
- Resampling: The easiest way would be to reduce our short-stayers to 320 to match our long-stayers. This approach is called under-sampling (see the sketch after this list). Another approach is to increase long-stayers to match short-stayers, either by collecting more data or by artificially creating new data, e.g., by adding random noise to existing patients' data. This approach is called over-sampling.
- Using algorithms like extreme gradient boosting, which can handle such class imbalance with aplomb.
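For instance, under-sampling is a one-liner with caret's downSample function; the seed value below is arbitrary.
# Under-sampling: randomly drop short-stayers until the classes are balanced
set.seed(42)
down_train = downSample(x = training_data, y = labels_train, yname = "target")
table(down_train$target)   # both classes now have the same count

# Retrain on the balanced data (target is now a column of down_train)
sol_down = train(target ~ ., data = down_train,
                 trControl = train_control, method = "rf")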
# Performance metrics
# Note: confusionMatrix expects the predictions first, then the reference labels
metrics = confusionMatrix(pre, labels_test, positive = "Short-stayers")
sensitivity = metrics$byClass[1]
specificity = metrics$byClass[2]
accuracy = metrics$overall[1]

# ROC/AUC (here computed from the hard class labels)
r.roc = roc(labels_test, as.numeric(pre))
r.roc$auc

# Variable importance
vi <- varImp(sol, scale = T)[[1]]
vi$var <- row.names(vi)
vi <- reshape2::melt(vi)

ggplot(vi, aes(value, var, col = variable)) +
  geom_point() +
  facet_wrap(~variable)
Finally, let's plot the variable importance to see which variables our RF model thinks are important in predicting the LOS categories. For this, we can use the varImp function within the caret package and then use ggplot to plot the importance values. As seen in the picture below, the total.cost variable is the most important variable for predicting which LOS category a patient belongs to. This is intuitive, as the longer the stay, the higher the total cost of the stay.
Conclusions
Phew! That was a long post, but much needed to understand how to perform a well-rounded classification with machine learning, covering all the facets involved. Random Forest was the algorithm we chose, but the framework is extendable to any algorithm. Just replace the 'rf' in the method argument of the train call with other algorithms available in the caret package and compare the results.
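For example, swapping in gradient boosted trees is a sketch like this, assuming the gbm package is installed; caret will pick a default tuning grid when none is supplied.
# Same framework, different algorithm
sol_gbm = train(training_data, labels_train, trControl = train_control,
                method = "gbm", verbose = FALSE)
pre_gbm = predict(sol_gbm, testing_data)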
In the next post, let’s shift focus towards the Regression problem using machine learning.