Random Forest model

Highlights & Limitations

How it works

Here is a simple randomForest() model using the iris dataset:

library(randomForest)
model <- randomForest(Species ~ ., data = iris, ntree = 100, proximity = TRUE)

The SQL translation returns a single SQL CASE WHEN expression. Each decision path becomes one WHEN clause.

library(tidypredict)
library(dplyr) # provides %>%, filter() and case_when() used in the examples

tidypredict_sql(model, dbplyr::simulate_mssql())
## <SQL> CASE
## WHEN (((`Petal.Length`) <= 2.5)) THEN ('setosa')
## WHEN ((((`Petal.Length`) > 5.05) AND ((`Petal.Length`) > 2.5))) THEN ('virginica')
## WHEN (((((`Petal.Width`) > 1.9) AND ((`Petal.Length`) > 2.5)) AND ((`Petal.Length`) <= 5.05))) THEN ('virginica')
## WHEN ((((((`Petal.Length`) > 2.5) AND ((`Sepal.Length`) <= 4.95)) AND ((`Petal.Width`) <= 1.9)) AND ((`Petal.Length`) <= 5.05))) THEN ('virginica')
## WHEN (((((((`Sepal.Length`) > 4.95) AND ((`Petal.Length`) > 2.5)) AND ((`Petal.Width`) <= 1.75)) AND ((`Petal.Width`) <= 1.9)) AND ((`Petal.Length`) <= 5.05))) THEN ('versicolor')
## WHEN ((((((((`Petal.Width`) > 1.75) AND ((`Sepal.Length`) > 4.95)) AND ((`Petal.Length`) > 2.5)) AND ((`Sepal.Width`) <= 3.0)) AND ((`Petal.Width`) <= 1.9)) AND ((`Petal.Length`) <= 5.05))) THEN ('virginica')
## WHEN ((((((((`Sepal.Width`) > 3.0) AND ((`Petal.Width`) > 1.75)) AND ((`Sepal.Length`) > 4.95)) AND ((`Petal.Length`) > 2.5)) AND ((`Petal.Width`) <= 1.9)) AND ((`Petal.Length`) <= 5.05))) THEN ('versicolor')
## END

Alternatively, use tidypredict_to_column() if the results are to be used or previewed in dplyr.

iris %>%
  tidypredict_to_column(model) %>%
  head(10)
##    Sepal.Length Sepal.Width Petal.Length Petal.Width Species    fit
## 1           5.1         3.5          1.4         0.2  setosa setosa
## 2           4.9         3.0          1.4         0.2  setosa setosa
## 3           4.7         3.2          1.3         0.2  setosa setosa
## 4           4.6         3.1          1.5         0.2  setosa setosa
## 5           5.0         3.6          1.4         0.2  setosa setosa
## 6           5.4         3.9          1.7         0.4  setosa setosa
## 7           4.6         3.4          1.4         0.3  setosa setosa
## 8           5.0         3.4          1.5         0.2  setosa setosa
## 9           4.4         2.9          1.4         0.2  setosa setosa
## 10          4.9         3.1          1.5         0.1  setosa setosa

Under the hood

The parser is based on the output of the randomForest::getTree() function. It returns one decision path for each row that has a non-NA value in the prediction field.

getTree(model, labelVar = TRUE) %>%
  head()
##   left daughter right daughter    split var split point status prediction
## 1             2              3 Petal.Length        2.50      1       <NA>
## 2             0              0         <NA>        0.00     -1     setosa
## 3             4              5 Petal.Length        5.05      1       <NA>
## 4             6              7  Petal.Width        1.90      1       <NA>
## 5             0              0         <NA>        0.00     -1  virginica
## 6             8              9 Sepal.Length        4.95      1       <NA>
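
The walk from such a table to a set of decision paths can be sketched in a few lines of base R. This is only an illustration of the idea, not the parser that tidypredict uses: the small tree below is made up (it is not the fitted model), the column names are shortened, and path_to_node() is a hypothetical helper. It assumes that rows going to the left daughter satisfy "<=" and rows going to the right daughter satisfy ">", which agrees with the SQL shown earlier.

```r
# Toy tree laid out like getTree(model, labelVar = TRUE).
# The values are invented for illustration; this is not the fitted model.
tree <- data.frame(
  left       = c(2, 0, 4, 0, 0),
  right      = c(3, 0, 5, 0, 0),
  split_var  = c("Petal.Length", NA, "Petal.Width", NA, NA),
  split_pt   = c(2.5, 0, 1.75, 0, 0),
  prediction = c(NA, "setosa", NA, "versicolor", "virginica"),
  stringsAsFactors = FALSE
)

# Hypothetical helper: climb from a terminal node to the root (row 1),
# recording one condition per step.
path_to_node <- function(tree, node) {
  conditions <- character(0)
  while (node != 1) {
    parent <- which(tree$left == node | tree$right == node)
    op <- if (tree$left[parent] == node) "<=" else ">"
    conditions <- c(
      conditions,
      paste(tree$split_var[parent], op, tree$split_pt[parent])
    )
    node <- parent
  }
  conditions
}

# One path per row with a non-NA prediction
leaves <- which(!is.na(tree$prediction))
paths <- lapply(leaves, function(n) path_to_node(tree, n))
names(paths) <- tree$prediction[leaves]
paths
```

The toy tree has three rows with a prediction, so the sketch yields three paths, each listing its conditions from the terminal node back to the root.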

The parsed model contains one row for each path. The field, operator and split_point fields each list every step of the path in a single concatenated character value.

parse_model(model)
## # A tibble: 8 x 7
##   labels vals         type     estimate field           operator  split_p…
##   <chr>  <chr>        <chr>       <dbl> <chr>           <chr>     <chr>   
## 1 path-1 setosa       path            0 Petal.Length    left      2.5     
## 2 path-2 virginica    path            0 Petal.Length{:… right{:}… 5.05{:}…
## 3 path-3 virginica    path            0 Petal.Width{:}… right{:}… 1.9{:}5…
## 4 path-4 virginica    path            0 Sepal.Length{:… left{:}l… 4.95{:}…
## 5 path-5 versicolor   path            0 Petal.Width{:}… left{:}r… 1.75{:}…
## 6 path-6 virginica    path            0 Sepal.Width{:}… left{:}r… 3{:}1.7…
## 7 path-7 versicolor   path            0 Sepal.Width{:}… right{:}… 3{:}1.7…
## 8 model  randomForest variable       NA <NA>            <NA>      <NA>
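
To see how a concatenated row could be turned back into individual conditions, here is a minimal sketch in base R. The three strings are illustrative values written in the same "{:}"-separated format; they are not the truncated entries of the table above. split_steps() is a hypothetical helper, and the mapping of "left" to "<=" and "right" to ">" is an assumption based on the SQL output shown earlier.

```r
# Illustrative values in the same "{:}"-separated format (not taken from the model)
field       <- "Petal.Width{:}Petal.Length"
operator    <- "left{:}right"
split_point <- "1.75{:}2.5"

# Hypothetical helper: split one concatenated value into its steps
split_steps <- function(x) strsplit(x, "{:}", fixed = TRUE)[[1]]

steps <- data.frame(
  field       = split_steps(field),
  operator    = ifelse(split_steps(operator) == "left", "<=", ">"),
  split_point = as.numeric(split_steps(split_point)),
  stringsAsFactors = FALSE
)

# One condition per step, joined into a single logical test
conditions <- paste(steps$field, steps$operator, steps$split_point)
path_test <- paste(conditions, collapse = " & ")
path_test
```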

The output from parse_model() is transformed into a dplyr, a.k.a. Tidy Eval, formula. The entire decision tree becomes one dplyr::case_when() statement.

tidypredict_fit(model)
## case_when(((Petal.Length) <= 2.5) ~ "setosa", (((Petal.Length) > 
##     5.05) & ((Petal.Length) > 2.5)) ~ "virginica", ((((Petal.Width) > 
##     1.9) & ((Petal.Length) > 2.5)) & ((Petal.Length) <= 5.05)) ~ 
##     "virginica", (((((Petal.Length) > 2.5) & ((Sepal.Length) <= 
##     4.95)) & ((Petal.Width) <= 1.9)) & ((Petal.Length) <= 5.05)) ~ 
##     "virginica", ((((((Sepal.Length) > 4.95) & ((Petal.Length) > 
##     2.5)) & ((Petal.Width) <= 1.75)) & ((Petal.Width) <= 1.9)) & 
##     ((Petal.Length) <= 5.05)) ~ "versicolor", (((((((Petal.Width) > 
##     1.75) & ((Sepal.Length) > 4.95)) & ((Petal.Length) > 2.5)) & 
##     ((Sepal.Width) <= 3)) & ((Petal.Width) <= 1.9)) & ((Petal.Length) <= 
##     5.05)) ~ "virginica", (((((((Sepal.Width) > 3) & ((Petal.Width) > 
##     1.75)) & ((Sepal.Length) > 4.95)) & ((Petal.Length) > 2.5)) & 
##     ((Petal.Width) <= 1.9)) & ((Petal.Length) <= 5.05)) ~ "versicolor")
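
As a rough illustration of how a formula of this kind can be evaluated, the sketch below builds a much smaller case_when() expression by hand with rlang::expr() and unquotes it inside mutate(). The expression is a hypothetical two-split tree, not the output of tidypredict_fit(), and the sketch assumes the formula behaves like an ordinary quoted R expression.

```r
library(dplyr)
library(rlang)

# Hypothetical hand-written formula; the split values are illustrative
fit_formula <- expr(case_when(
  Petal.Length <= 2.5 ~ "setosa",
  Petal.Width <= 1.75 ~ "versicolor",
  TRUE ~ "virginica"
))

# Unquote the expression so that mutate() evaluates it row by row
result <- iris %>%
  mutate(fit = !!fit_formula)

# Cross-tabulate the actual species against the fitted labels
result %>%
  count(Species, fit)
```

Because the expression is only evaluated once it is unquoted, the same object can be inspected, printed or passed to a translation step before any data is touched.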

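From there, the Tidy Eval formula can be used anywhere a quoted expression can be evaluated. tidypredict provides three paths: use the formula directly inside dplyr, for example mutate(iris, !! tidypredict_fit(model)); add tidypredict_to_column(model) to a piped command set; or use tidypredict_sql(model) to retrieve the SQL statement.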

How it performs

Currently, the formula matches 147 of the 150 predictions of the test model. The threshold argument of tidypredict_test() is an integer giving the number of records that are allowed to differ from the baseline predictions returned by predict().

test <- tidypredict_test(model, iris, threshold = 5)

test
## tidypredict test results
## 
## Success, test is under the set threshold of: 5
## Predictions that did not match predict(): 3

test$raw_results %>%
  filter(predict != tidypredict)
##   Sepal.Length Sepal.Width Petal.Length Petal.Width    Species    predict
## 1          4.9         2.4          3.3         1.0 versicolor versicolor
## 2          6.0         2.7          5.1         1.6 versicolor versicolor
## 3          6.0         2.2          5.0         1.5  virginica  virginica
##   tidypredict
## 1   virginica
## 2   virginica
## 3  versicolor
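
The idea behind the threshold can be pictured with a short base R sketch. The two vectors and the threshold value are made up for illustration, and the use of "<=" in the comparison is an assumption; the text above does not say whether tidypredict_test() treats a count equal to the threshold as a pass.

```r
# Made-up prediction vectors, for illustration only
baseline   <- c("setosa", "versicolor", "virginica", "virginica", "setosa")
translated <- c("setosa", "virginica",  "virginica", "virginica", "setosa")

threshold <- 1  # hypothetical tolerance

# Count the records where the two sets of predictions disagree
mismatches <- sum(baseline != translated)

# Assumption: a count equal to the threshold still passes
within_threshold <- mismatches <= threshold

mismatches
within_threshold
```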