Skip to content

Commit 2ad5fa0

Browse files
authored
fix: update xgboost api and add depends R>=4.1.0 to description (#51)
* chore: added depends R>=4.1.0 to description to fix cran note * fix: update code to new v3 xgboost api * fix: updated more old xgboost api calls * test: fixed treeshap-correctness tests for xgboost * test: fixed xgboost unit-tests; restored expected behaviour of 'predictions from unified == original predictions' by now controlling the 'base_score' parameter in order to prevent automated calculation of the intercept, as it was introduced to xgboost with v2 and later versions * test: fix issue with ambiguous argument; renamed argument 'model' -> 'model_name' to avoid issues with also used object named 'model' * fix: back to new split-condition logic for xgboost as split value is considered to be 'less than' * test: quite a hacky workaround to get original unit-tests working * chore: code formatting and removed library-import of treeshap at beginning of test-files * test: conditional testing of examples and unit-tests using suggested deps * chore: re-formatted examples for randomForest and fixed indentation for conditional testing * chore: replaced deprecated size argument of ggplot2's element_line * chore: some house-keeping; renamed globals.R -> zzz.R and moved all import statements there
1 parent a1a5472 commit 2ad5fa0

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

42 files changed

+2163
-1435
lines changed

DESCRIPTION

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
Package: treeshap
22
Title: Compute SHAP Values for Your Tree-Based Models Using the 'TreeSHAP'
33
Algorithm
4-
Version: 0.3.1.9000
4+
Version: 0.3.2
55
Authors@R: c(
66
person("Konrad", "Komisarczyk", email = "komisarczykkonrad@gmail.com", role = "aut"),
77
person("Pawel", "Kozminski", email = "pkozminski99@gmail.com", role = "aut"),
@@ -21,7 +21,7 @@ URL: https://modeloriented.github.io/treeshap/,
2121
https://github.com/ModelOriented/treeshap
2222
BugReports: https://github.com/ModelOriented/treeshap/issues
2323
Depends:
24-
R (>= 2.10)
24+
R (>= 4.1.0)
2525
Imports:
2626
data.table,
2727
ggplot2,
@@ -41,4 +41,4 @@ LinkingTo:
4141
Encoding: UTF-8
4242
LazyData: true
4343
Roxygen: list(markdown = TRUE)
44-
RoxygenNote: 7.2.3
44+
RoxygenNote: 7.3.3

R/globals.R

Lines changed: 0 additions & 3 deletions
This file was deleted.

R/plot_contribution.R

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,6 @@
1717
#'
1818
#' @export
1919
#'
20-
#' @import ggplot2
21-
#'
2220
#' @seealso
2321
#' \code{\link{treeshap}} for calculation of SHAP values
2422
#'
@@ -27,17 +25,24 @@
2725
#'
2826
#' @examples
2927
#' \donttest{
30-
#' library(xgboost)
31-
#' data <- fifa20$data[colnames(fifa20$data) != 'work_rate']
32-
#' target <- fifa20$target
33-
#' param <- list(objective = "reg:squarederror", max_depth = 3)
34-
#' xgb_model <- xgboost::xgboost(as.matrix(data), params = param, label = target,
35-
#' nrounds = 20, verbose = FALSE)
36-
#' unified_model <- xgboost.unify(xgb_model, as.matrix(data))
37-
#' x <- head(data, 1)
38-
#' shap <- treeshap(unified_model, x)
39-
#' plot_contribution(shap, 1, min_max = c(0, 120000000))
40-
#' }
28+
#' if (requireNamespace("xgboost", quietly = TRUE) &&
29+
#' requireNamespace("scales", quietly = TRUE)) {
30+
#' library(xgboost)
31+
#' data <- fifa20$data[colnames(fifa20$data) != 'work_rate']
32+
#' target <- fifa20$target
33+
#' xgb_model <- xgboost::xgboost(
34+
#' x = as.matrix(data),
35+
#' y = target,
36+
#' objective = "reg:squarederror",
37+
#' max_depth = 3,
38+
#' nrounds = 20
39+
#' )
40+
#' unified_model <- xgboost.unify(xgb_model, as.matrix(data))
41+
#' x <- head(data, 1)
42+
#' shap <- treeshap(unified_model, x)
43+
#' plot_contribution(shap, 1, min_max = c(0, 120000000))
44+
#' }}
45+
#'
4146
plot_contribution <- function(treeshap,
4247
obs = 1,
4348
max_vars = 5,
@@ -46,6 +51,12 @@ plot_contribution <- function(treeshap,
4651
explain_deviation = FALSE,
4752
title = "SHAP Break-Down",
4853
subtitle = "") {
54+
if (!requireNamespace("scales", quietly = TRUE)) {
55+
stop(
56+
"Package \"scales\" needed for this function to work. Please install it.",
57+
call. = FALSE
58+
)
59+
}
4960

5061
shap <- treeshap$shaps[obs, ]
5162
model <- treeshap$unified_model$model

R/plot_feature_dependence.R

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,6 @@
1111
#'
1212
#' @export
1313
#'
14-
#' @import ggplot2
15-
#'
1614
#' @seealso
1715
#' \code{\link{treeshap}} for calculation of SHAP values
1816
#'
@@ -21,19 +19,31 @@
2119
#'
2220
#' @examples
2321
#' \donttest{
24-
#' library(xgboost)
25-
#' data <- fifa20$data[colnames(fifa20$data) != 'work_rate']
26-
#' target <- fifa20$target
27-
#' param <- list(objective = "reg:squarederror", max_depth = 3)
28-
#' xgb_model <- xgboost::xgboost(as.matrix(data), params = param, label = target,
29-
#' nrounds = 20, verbose = FALSE)
30-
#' unified_model <- xgboost.unify(xgb_model, as.matrix(data))
31-
#' x <- head(data, 100)
32-
#' shaps <- treeshap(unified_model, x)
33-
#' plot_feature_dependence(shaps, variable = "overall")
34-
#' }
22+
#' if (requireNamespace("xgboost", quietly = TRUE) &&
23+
#' requireNamespace("scales", quietly = TRUE)) {
24+
#' library(xgboost)
25+
#' data <- fifa20$data[colnames(fifa20$data) != 'work_rate']
26+
#' target <- fifa20$target
27+
#' xgb_model <- xgboost::xgboost(
28+
#' x = as.matrix(data),
29+
#' y = target,
30+
#' objective = "reg:squarederror",
31+
#' max_depth = 3,
32+
#' nrounds = 20
33+
#' )
34+
#' unified_model <- xgboost.unify(xgb_model, as.matrix(data))
35+
#' x <- head(data, 100)
36+
#' shaps <- treeshap(unified_model, x)
37+
#' plot_feature_dependence(shaps, variable = "overall")
38+
#' }}
3539
plot_feature_dependence <- function(treeshap, variable,
3640
title = "Feature Dependence", subtitle = NULL) {
41+
if (!requireNamespace("scales", quietly = TRUE)) {
42+
stop(
43+
"Package \"scales\" needed for this function to work. Please install it.",
44+
call. = FALSE
45+
)
46+
}
3747

3848
shaps <- treeshap$shaps
3949
x <- treeshap$observations

R/plot_feature_importance.R

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,6 @@
1313
#' @return a \code{ggplot2} object
1414
#'
1515
#' @export
16-
#' @import ggplot2
17-
#' @importFrom stats reorder
18-
#' @importFrom graphics text
1916
#'
2017
#' @seealso
2118
#' \code{\link{treeshap}} for calculation of SHAP values
@@ -25,21 +22,33 @@
2522
#'
2623
#' @examples
2724
#' \donttest{
28-
#' library(xgboost)
29-
#' data <- fifa20$data[colnames(fifa20$data) != 'work_rate']
30-
#' target <- fifa20$target
31-
#' param <- list(objective = "reg:squarederror", max_depth = 3)
32-
#' xgb_model <- xgboost::xgboost(as.matrix(data), params = param, label = target,
33-
#' nrounds = 20, verbose = FALSE)
34-
#' unified_model <- xgboost.unify(xgb_model, as.matrix(data))
35-
#' shaps <- treeshap(unified_model, as.matrix(head(data, 3)))
36-
#' plot_feature_importance(shaps, max_vars = 4)
37-
#' }
25+
#' if (requireNamespace("xgboost", quietly = TRUE) &&
26+
#' requireNamespace("scales", quietly = TRUE)) {
27+
#' library(xgboost)
28+
#' data <- fifa20$data[colnames(fifa20$data) != 'work_rate']
29+
#' target <- fifa20$target
30+
#' xgb_model <- xgboost::xgboost(
31+
#' x = as.matrix(data),
32+
#' y = target,
33+
#' objective = "reg:squarederror",
34+
#' max_depth = 3,
35+
#' nrounds = 20
36+
#' )
37+
#' unified_model <- xgboost.unify(xgb_model, as.matrix(data))
38+
#' shaps <- treeshap(unified_model, as.matrix(head(data, 3)))
39+
#' plot_feature_importance(shaps, max_vars = 4)
40+
#' }}
3841
plot_feature_importance <- function(treeshap,
3942
desc_sorting = TRUE,
4043
max_vars = ncol(shaps),
4144
title = "Feature Importance",
4245
subtitle = NULL) {
46+
if (!requireNamespace("scales", quietly = TRUE)) {
47+
stop(
48+
"Package \"scales\" needed for this function to work. Please install it.",
49+
call. = FALSE
50+
)
51+
}
4352
shaps <- treeshap$shaps
4453

4554
# argument check

R/plot_interaction.R

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,6 @@
1313
#'
1414
#' @export
1515
#'
16-
#' @import ggplot2
17-
#'
1816
#' @seealso
1917
#' \code{\link{treeshap}} for calculation of SHAP Interaction values
2018
#'
@@ -23,17 +21,30 @@
2321
#'
2422
#' @examples
2523
#' \donttest{
26-
#' data <- fifa20$data[colnames(fifa20$data) != 'work_rate']
27-
#' target <- fifa20$target
28-
#' param2 <- list(objective = "reg:squarederror", max_depth = 5)
29-
#' xgb_model2 <- xgboost::xgboost(as.matrix(data), params = param2, label = target, nrounds = 10)
30-
#' unified_model2 <- xgboost.unify(xgb_model2, data)
31-
#' inters <- treeshap(unified_model2, as.matrix(data[1:50, ]), interactions = TRUE)
32-
#' plot_interaction(inters, "dribbling", "defending")
33-
#' }
24+
#' if (requireNamespace("xgboost", quietly = TRUE) &&
25+
#' requireNamespace("scales", quietly = TRUE)) {
26+
#' data <- fifa20$data[colnames(fifa20$data) != 'work_rate']
27+
#' target <- fifa20$target
28+
#' xgb_model2 <- xgboost::xgboost(
29+
#' x = as.matrix(data),
30+
#' y = target,
31+
#' objective = "reg:squarederror",
32+
#' max_depth = 5,
33+
#' nrounds = 10
34+
#' )
35+
#' unified_model2 <- xgboost.unify(xgb_model2, data)
36+
#' inters <- treeshap(unified_model2, as.matrix(data[1:50, ]), interactions = TRUE)
37+
#' plot_interaction(inters, "dribbling", "defending")
38+
#' }}
3439
plot_interaction <- function(treeshap, var1, var2,
3540
title = "SHAP Interaction Value Plot",
3641
subtitle = "") {
42+
if (!requireNamespace("scales", quietly = TRUE)) {
43+
stop(
44+
"Package \"scales\" needed for this function to work. Please install it.",
45+
call. = FALSE
46+
)
47+
}
3748

3849
interactions <- treeshap$interactions
3950
x <- treeshap$observations

R/predict.R

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,19 +12,21 @@
1212
#'
1313
#' @examples
1414
#' \donttest{
15-
#' library(gbm)
16-
#' data <- fifa20$data[colnames(fifa20$data) != 'work_rate']
17-
#' data['value_eur'] <- fifa20$target
18-
#' gbm_model <- gbm::gbm(
19-
#' formula = value_eur ~ .,
20-
#' data = data,
21-
#' distribution = "laplace",
22-
#' n.trees = 20,
23-
#' interaction.depth = 4,
24-
#' n.cores = 1)
15+
#' if (requireNamespace("gbm", quietly = TRUE)) {
16+
#' library(gbm)
17+
#' data <- fifa20$data[colnames(fifa20$data) != 'work_rate']
18+
#' data['value_eur'] <- fifa20$target
19+
#' gbm_model <- gbm::gbm(
20+
#' formula = value_eur ~ .,
21+
#' data = data,
22+
#' distribution = "laplace",
23+
#' n.trees = 20,
24+
#' interaction.depth = 4,
25+
#' n.cores = 1
26+
#' )
2527
#' unified <- gbm.unify(gbm_model, data)
2628
#' predict(unified, data[2001:2005, ])
27-
#' }
29+
#' }}
2830
predict.model_unified <- function(object, x, ...) {
2931
unified_model <- object
3032
model <- unified_model$model

R/set_reference_dataset.R

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -27,19 +27,20 @@
2727
#'
2828
#' @examples
2929
#' \donttest{
30-
#' library(gbm)
31-
#' data <- fifa20$data[colnames(fifa20$data) != 'work_rate']
32-
#' data['value_eur'] <- fifa20$target
33-
#' gbm_model <- gbm::gbm(
34-
#' formula = value_eur ~ .,
35-
#' data = data,
36-
#' distribution = "laplace",
37-
#' n.trees = 20,
38-
#' interaction.depth = 4,
39-
#' n.cores = 1)
40-
#' unified <- gbm.unify(gbm_model, data)
41-
#' set_reference_dataset(unified, data[200:700, ])
42-
#' }
30+
#' if (requireNamespace("gbm", quietly = TRUE)) {
31+
#' library(gbm)
32+
#' data <- fifa20$data[colnames(fifa20$data) != 'work_rate']
33+
#' data['value_eur'] <- fifa20$target
34+
#' gbm_model <- gbm::gbm(
35+
#' formula = value_eur ~ .,
36+
#' data = data,
37+
#' distribution = "laplace",
38+
#' n.trees = 20,
39+
#' interaction.depth = 4,
40+
#' n.cores = 1)
41+
#' unified <- gbm.unify(gbm_model, data)
42+
#' set_reference_dataset(unified, data[200:700, ])
43+
#' }}
4344
set_reference_dataset <- function(unified_model, x) {
4445
model <- unified_model$model
4546
data <- x

0 commit comments

Comments
 (0)