# ggplot backend for the xgboost plotting facilities

#' @rdname xgb.plot.importance
#' @export
xgb.ggplot.importance <- function(importance_matrix = NULL, top_n = NULL, measure = NULL,
                                  rel_to_first = FALSE, n_clusters = c(1:10), ...) {
  # ggplot2 rendition of the feature-importance chart: one horizontal bar per
  # feature, filled according to a 1-D clustering of the importance values
  # computed by Ckmeans.1d.dp. Returns the ggplot object (not printed).

  # Check the optional dependencies first so a missing package fails fast,
  # before any importance preprocessing is done.
  if (!requireNamespace("ggplot2", quietly = TRUE)) {
    stop("ggplot2 package is required", call. = FALSE)
  }
  if (!requireNamespace("Ckmeans.1d.dp", quietly = TRUE)) {
    stop("Ckmeans.1d.dp package is required", call. = FALSE)
  }

  # Delegate the top_n / measure / rel_to_first handling to the base helper;
  # with plot = FALSE it hands back the table (Feature, Importance, ...)
  # instead of drawing.
  importance_matrix <- xgb.plot.importance(importance_matrix, top_n = top_n, measure = measure,
                                           rel_to_first = rel_to_first, plot = FALSE, ...)

  # Warnings from Ckmeans.1d.dp are deliberately silenced.
  clusters <- suppressWarnings(
    Ckmeans.1d.dp::Ckmeans.1d.dp(importance_matrix$Importance, n_clusters)
  )
  # Added by reference (data.table `:=`); stored as character so that the
  # fill scale is discrete.
  importance_matrix[, Cluster := as.character(clusters$cluster)]

  # coord_flip() draws the first factor level at the bottom, so reverse the
  # levels to keep the table's first feature at the top of the chart.
  ggplot2::ggplot(
    importance_matrix,
    ggplot2::aes(x = factor(Feature, levels = rev(Feature)), y = Importance)
  ) +
    ggplot2::geom_bar(ggplot2::aes(fill = Cluster), stat = "identity",
                      position = "identity", width = 0.5) +
    ggplot2::coord_flip() +
    ggplot2::xlab("Features") +
    ggplot2::ggtitle("Feature importance") +
    ggplot2::theme(
      plot.title = ggplot2::element_text(lineheight = .9, face = "bold"),
      panel.grid.major.y = ggplot2::element_blank()
    )
}

#' @rdname xgb.plot.deepness
#' @export
xgb.ggplot.deepness <- function(model = NULL, which = c("2x1", "max.depth", "med.depth", "med.weight")) {
  # ggplot2 rendition of xgb.plot.deepness(); `which` selects the chart.
  # "2x1" draws two stacked charts and returns them invisibly as a list;
  # every other choice returns a single ggplot object.
  if (!requireNamespace("ggplot2", quietly = TRUE)) {
    stop("ggplot2 package is required for plotting the graph deepness.", call. = FALSE)
  }
  which <- match.arg(which)

  # Leaf-level table from the base helper; columns used below are
  # Tree, Depth, Cover and Weight.
  dt_depths <- xgb.plot.deepness(model = model, plot = FALSE)

  if (which == "2x1") {
    # Leaf count (.N) and mean cover for each depth, ordered by depth.
    dt_summaries <- dt_depths[, .(.N, Cover = mean(Cover)), Depth]
    setkey(dt_summaries, "Depth")

    p1 <- ggplot2::ggplot(dt_summaries) +
      ggplot2::geom_bar(ggplot2::aes(x = Depth, y = N), stat = "Identity") +
      ggplot2::xlab("") +
      ggplot2::ylab("Number of leafs") +
      ggplot2::ggtitle("Model complexity") +
      ggplot2::theme(
        plot.title = ggplot2::element_text(lineheight = 0.9, face = "bold"),
        panel.grid.major.y = ggplot2::element_blank(),
        axis.ticks = ggplot2::element_blank(),
        axis.text.x = ggplot2::element_blank()
      )
    p2 <- ggplot2::ggplot(dt_summaries) +
      ggplot2::geom_bar(ggplot2::aes(x = Depth, y = Cover), stat = "Identity") +
      ggplot2::xlab("Leaf depth") +
      ggplot2::ylab("Weighted cover")

    # Stack both charts in one column; p1 hides its x labels because p2,
    # drawn below it, carries the shared depth axis.
    multiplot(p1, p2, cols = 1)
    return(invisible(list(p1, p2)))
  }

  # The remaining charts show one statistic per tree; data.table names the
  # unnamed aggregate column V1.
  if (which == "med.weight") {
    return(
      ggplot2::ggplot(dt_depths[, median(abs(Weight)), Tree]) +
        ggplot2::geom_point(ggplot2::aes(x = Tree, y = V1), alpha = 0.4, size = 3, stroke = 0) +
        ggplot2::xlab("tree #") +
        ggplot2::ylab("Median absolute leaf weight")
    )
  }

  # Depth statistics are jittered vertically (height 0.15) to reduce
  # overplotting of trees that share the same value.
  if (which == "max.depth") {
    dt_stat <- dt_depths[, max(Depth), Tree]
    y_label <- "Max tree leaf depth"
  } else {
    dt_stat <- dt_depths[, median(as.numeric(Depth)), Tree]
    y_label <- "Median tree leaf depth"
  }
  ggplot2::ggplot(dt_stat) +
    ggplot2::geom_jitter(ggplot2::aes(x = Tree, y = V1), height = 0.15, alpha = 0.4, size = 3, stroke = 0) +
    ggplot2::xlab("tree #") +
    ggplot2::ylab(y_label)
}

#' @rdname xgb.plot.shap.summary
#' @export
xgb.ggplot.shap.summary <- function(data, shap_contrib = NULL, features = NULL, top_n = 10, model = NULL,
                                    trees = NULL, target_class = NULL, approxcontrib = FALSE, subsample = NULL) {
  # SHAP summary ("beeswarm"-style) chart: one jittered point per observation
  # and feature, SHAP value on the value axis, coloured by the standardized
  # feature value. Returns the ggplot object.
  if (!requireNamespace("ggplot2", quietly = TRUE)) {
    stop("ggplot2 package is required", call. = FALSE)
  }

  data_list <- xgb.shap.data(
    data = data,
    shap_contrib = shap_contrib,
    features = features,
    top_n = top_n,
    model = model,
    trees = trees,
    target_class = target_class,
    approxcontrib = approxcontrib,
    subsample = subsample,
    max_observations = 10000 # 10,000 samples per feature.
  )
  p_data <- prepare.ggplot.shap.data(data_list, normalize = TRUE)
  # Reverse factor levels so that the first level is at the top of the plot
  p_data[, "feature" := factor(feature, rev(levels(feature)))]

  # Columns are mapped by name: `p_data$col` inside aes() bypasses ggplot2's
  # data mask and is flagged with a warning by recent ggplot2 versions.
  ggplot2::ggplot(p_data, ggplot2::aes(x = feature, y = shap_value, colour = feature_value)) +
    ggplot2::geom_jitter(alpha = 0.5, width = 0.1) +
    # Feature values are standardized, so clamp the colour range to +/- 3 sd.
    ggplot2::scale_colour_viridis_c(limits = c(-3, 3), option = "plasma", direction = -1) +
    ggplot2::geom_abline(slope = 0, intercept = 0, colour = "darkgrey") +
    ggplot2::coord_flip()
}

# Columns of the melted SHAP table that xgb.ggplot.shap.summary() references
# unquoted inside aes(); declared so R CMD check sees a binding for them.
globalVariables(c("shap_value", "feature_value"))

#' Combine and melt feature values and SHAP contributions for sample
#' observations.
#'
#' Conforms to data format required for ggplot functions.
#'
#' Internal utility function.
#'
#' @param data_list List containing 'data' and 'shap_contrib' returned by
#'   \code{xgb.shap.data()}.
#' @param normalize Whether to standardize feature values to have mean 0 and
#'   standard deviation 1 (useful for comparing multiple features on the same
#'   plot). Default \code{FALSE}.
#'
#' @return A data.table containing the observation ID, the feature name, the
#'   feature value (normalized if specified), and the SHAP contribution value.
prepare.ggplot.shap.data <- function(data_list, normalize = FALSE) {
  # Feature values: one column per feature, one row per observation.
  data <- data.table::as.data.table(as.matrix(data_list[["data"]]))
  if (normalize) {
    # The logical `normalize` argument masks the normalize() helper defined
    # in this file. Fetch the function binding explicitly (mode = "function"
    # skips the logical) instead of relying on lapply()'s match.fun() lookup.
    standardize <- get("normalize", mode = "function")
    data[, (names(data)) := lapply(.SD, standardize)]
  }
  # Row index shared by both tables so the melted rows can be re-joined.
  data[, "id" := seq_len(nrow(data))]
  data_m <- data.table::melt.data.table(data, id.vars = "id", variable.name = "feature",
                                        value.name = "feature_value")

  shap_contrib <- data.table::as.data.table(as.matrix(data_list[["shap_contrib"]]))
  shap_contrib[, "id" := seq_len(nrow(shap_contrib))]
  shap_contrib_m <- data.table::melt.data.table(shap_contrib, id.vars = "id", variable.name = "feature",
                                                value.name = "shap_value")

  # Long format: one row per (observation, feature) pair.
  data.table::merge.data.table(data_m, shap_contrib_m, by = c("id", "feature"))
}

#' Scale feature value to have mean 0, standard deviation 1
#'
#' This is used to compare multiple features on the same plot.
#' Internal utility function
#'
#' @param x Numeric vector
#'
#' @return Numeric vector with mean 0 and sd 1. When \code{x} has no spread
#'   (a constant vector, or fewer than two non-missing values) every
#'   non-missing value maps to 0 rather than \code{NaN}; missing values stay
#'   \code{NA}.
normalize <- function(x) {
  loc <- mean(x, na.rm = TRUE)
  scale <- stats::sd(x, na.rm = TRUE)
  if (is.na(scale) || scale == 0) {
    # Every observed value equals the mean, i.e. a z-score of 0; dividing by
    # a zero/NA sd would otherwise yield NaN/NA for the whole vector.
    return(ifelse(is.na(x), NA_real_, 0))
  }
  (x - loc) / scale
}

# Plot multiple ggplot graphs aligned by rows and columns.
# ...   the ggplot objects to draw
# cols  number of columns of the layout grid
# Plots fill the grid column by column (the layout matrix is column-major).
# Returns invisibly. Internal utility function.
multiplot <- function(..., cols = 1) {
  plots <- list(...)
  num_plots <- length(plots)

  if (num_plots == 0) {
    # Nothing to draw; avoid building an empty layout.
    return(invisible(NULL))
  }
  if (num_plots == 1) {
    return(print(plots[[1]]))
  }

  n_rows <- ceiling(num_plots / cols)
  layout <- matrix(seq_len(cols * n_rows), ncol = cols, nrow = n_rows)

  grid::grid.newpage()
  grid::pushViewport(grid::viewport(layout = grid::grid.layout(nrow(layout), ncol(layout))))
  for (i in seq_len(num_plots)) {
    # Row/column of the layout cell that holds subplot i.
    pos <- which(layout == i, arr.ind = TRUE)
    print(
      plots[[i]],
      vp = grid::viewport(layout.pos.row = pos[, "row"], layout.pos.col = pos[, "col"])
    )
  }
  invisible(NULL)
}

globalVariables(c(
  "Cluster", "ggplot", "aes", "geom_bar", "coord_flip", "xlab", "ylab", "ggtitle",
  "theme", "element_blank", "element_text", "V1", "Weight", "feature"
))