diff --git a/R/unify_ranger.R b/R/unify_ranger.R index 45c5380..a3d04a4 100644 --- a/R/unify_ranger.R +++ b/R/unify_ranger.R @@ -41,6 +41,10 @@ ranger.unify <- function(rf_model, data) { n <- rf_model$num.trees x <- lapply(1:n, function(tree) { tree_data <- data.table::as.data.table(ranger::treeInfo(rf_model, tree = tree)) + # Fix for probability forests + if (rf_model$treetype == "Probability estimation") { + data.table::setnames(tree_data, "pred.1", "prediction") + } tree_data[, c("nodeID", "leftChild", "rightChild", "splitvarName", "splitval", "prediction")] }) return(ranger_unify.common(x = x, n = n, data = data, feature_names = rf_model$forest$independent.variable.names))