#' Parse a boosted tree model text dump
#'
#' Parse a boosted tree model text dump into a \code{data.table} structure.
#'
#' @param feature_names character vector of feature names. If the model already
#'        contains feature names, those would be used when \code{feature_names=NULL} (default value).
#'        Non-null \code{feature_names} could be provided to override those in the model.
#' @param model object of class \code{xgb.Booster}
#' @param text \code{character} vector previously generated by the \code{xgb.dump}
#'        function (where parameter \code{with_stats = TRUE} should have been set).
#'        \code{text} takes precedence over \code{model}.
#' @param trees an integer vector of tree indices that should be parsed.
#'        If set to \code{NULL}, all trees of the model are parsed.
#'        It could be useful, e.g., in multiclass classification to get only
#'        the trees of one certain class. IMPORTANT: the tree index in xgboost models
#'        is zero-based (e.g., use \code{trees = 0:4} for first 5 trees).
#'        Indices outside the range of trees present in the dump are silently dropped.
#' @param use_int_id a logical flag indicating whether nodes in columns "Yes", "No", "Missing" should be
#'        represented as integers (when TRUE) or as "Tree-Node" character strings (when FALSE).
#' @param ... currently not used.
#'
#' @return
#' A \code{data.table} with detailed information about model trees' nodes.
#'
#' The columns of the \code{data.table} are:
#'
#' \itemize{
#'  \item \code{Tree}: integer ID of a tree in a model (zero-based index)
#'  \item \code{Node}: integer ID of a node in a tree (zero-based index)
#'  \item \code{ID}: character identifier of a node in a model (only when \code{use_int_id=FALSE})
#'  \item \code{Feature}: for a branch node, it's a feature id or name (when available);
#'              for a leaf node, it simply labels it as \code{'Leaf'}
#'  \item \code{Split}: location of the split for a branch node (split condition is always "less than")
#'  \item \code{Yes}: ID of the next node when the split condition is met
#'  \item \code{No}: ID of the next node when the split condition is not met
#'  \item \code{Missing}: ID of the next node when branch value is missing
#'  \item \code{Quality}: either the split gain (change in loss) or the leaf value
#'  \item \code{Cover}: metric related to the number of observations either seen by a split
#'                      or collected by a leaf during training.
#' }
#'
#' When \code{use_int_id=FALSE}, columns "Yes", "No", and "Missing" point to model-wide node identifiers
#' in the "ID" column. When \code{use_int_id=TRUE}, those columns point to node identifiers from
#' the corresponding trees in the "Node" column.
#'
#' An error is raised when the dump contains no leaf lines (e.g. a linear booster).
#'
#' @examples
#' # Basic use:
#'
#' data(agaricus.train, package='xgboost')
#'
#' bst <- xgboost(data = agaricus.train$data, label = agaricus.train$label, max_depth = 2,
#'                eta = 1, nthread = 2, nrounds = 2, objective = "binary:logistic")
#'
#' (dt <- xgb.model.dt.tree(colnames(agaricus.train$data), bst))
#'
#' # This bst model already has feature_names stored with it, so those would be used when
#' # feature_names is not set:
#' (dt <- xgb.model.dt.tree(model = bst))
#'
#' # How to match feature names of splits that are following a current 'Yes' branch:
#'
#' merge(dt, dt[, .(ID, Y.Feature=Feature)], by.x='Yes', by.y='ID', all.x=TRUE)[order(Tree,Node)]
#'
#' @export
xgb.model.dt.tree <- function(feature_names = NULL, model = NULL, text = NULL,
                              trees = NULL, use_int_id = FALSE, ...) {
  check.deprecation(...)

  if (!inherits(model, "xgb.Booster") && !is.character(text)) {
    stop("Either 'model' must be an object of class xgb.Booster\n",
         " or 'text' must be a character vector with the result of xgb.dump\n",
         " (or NULL if 'model' was provided).")
  }

  # Fall back to the names stored in the model only when none were supplied.
  if (is.null(feature_names) && !is.null(model) && !is.null(model$feature_names)) {
    feature_names <- model$feature_names
  }

  if (!(is.null(feature_names) || is.character(feature_names))) {
    stop("feature_names: must be a character vector")
  }

  if (!(is.null(trees) || is.numeric(trees))) {
    stop("trees: must be a vector of integers.")
  }

  if (is.null(text)) {
    text <- xgb.dump(model = model, with_stats = TRUE)
  }

  # Leaf values can be negative (e.g. "leaf=-0.5"), so look for the "leaf="
  # token itself rather than requiring a digit immediately after it.
  if (length(text) < 2 || !any(grepl("leaf=", text, fixed = TRUE))) {
    stop("Non-tree model detected! This function can only be used with tree models.")
  }

  position <- grep("booster", text, fixed = TRUE)

  # Node references are either kept as-is (integer mode) or prefixed with the
  # tree index to form model-wide "Tree-Node" identifiers.
  add.tree.id <- function(node, tree) if (use_int_id) node else paste(tree, node, sep = "-")

  anynumber_regex <- "[-+]?[0-9]*\\.?[0-9]+([eE][-+]?[0-9]+)?"

  td <- data.table(t = text)
  td[position, Tree := 1L]
  # Every "booster" header line opens a new tree; number trees from zero.
  td[, Tree := cumsum(!is.na(Tree)) - 1L]

  if (is.null(trees)) {
    trees <- 0:max(td$Tree)
  } else {
    trees <- trees[trees >= 0 & trees <= max(td$Tree)]
  }
  td <- td[Tree %in% trees & !grepl("^booster", t)]

  td[, Node := as.integer(sub("^([0-9]+):.*", "\\1", t))]
  if (!use_int_id) td[, ID := add.tree.id(Node, Tree)]
  td[, isLeaf := grepl("leaf", t, fixed = TRUE)]

  # ---- parse branch lines ----
  branch_rx <- paste0("f(\\d+)<(", anynumber_regex, ")\\] yes=(\\d+),no=(\\d+),missing=(\\d+),",
                      "gain=(", anynumber_regex, "),cover=(", anynumber_regex, ")")
  branch_cols <- c("Feature", "Split", "Yes", "No", "Missing", "Quality", "Cover")
  td[
    isLeaf == FALSE,
    (branch_cols) := {
      matches <- regmatches(t, regexec(branch_rx, t))
      # Skip indices of the spurious exponent capture groups from anynumber_regex.
      xtr <- do.call(rbind, matches)[, c(2, 3, 5, 6, 7, 8, 10), drop = FALSE]
      if (length(xtr) == 0) {
        # Nothing matched (presumably stump-only trees with no branch rows):
        # return placeholder columns so the branch columns still get created.
        as.data.table(
          list(Feature = "NA", Split = "NA", Yes = "NA", No = "NA", Missing = "NA",
               Quality = "NA", Cover = "NA")
        )
      } else {
        # Columns 3:5 are the Yes/No/Missing child node references.
        xtr[, 3:5] <- add.tree.id(xtr[, 3:5], Tree)
        as.data.table(xtr)
      }
    }
  ]

  # ---- assign feature_names when available ----
  # Only branch rows carry (zero-based) feature indices; skip the mapping when
  # there are none (e.g. stump-only trees), where all Feature values are NA.
  has_branches <- any(!td$isLeaf)
  if (!is.null(feature_names) && has_branches) {
    if (length(feature_names) <= max(as.numeric(td$Feature), na.rm = TRUE)) {
      stop("feature_names has less elements than there are features used in the model")
    }
    td[isLeaf == FALSE, Feature := feature_names[as.numeric(Feature) + 1]]
  }

  # ---- parse leaf lines ----
  leaf_rx <- paste0("leaf=(", anynumber_regex, "),cover=(", anynumber_regex, ")")
  leaf_cols <- c("Feature", "Quality", "Cover")
  td[
    isLeaf == TRUE,
    (leaf_cols) := {
      matches <- regmatches(t, regexec(leaf_rx, t))
      # Capture 2 is the leaf value, 4 the cover (3 is the value's exponent).
      # drop = FALSE keeps a matrix even for a single leaf row.
      xtr <- do.call(rbind, matches)[, c(2, 4), drop = FALSE]
      c("Leaf", as.data.table(xtr))
    }
  ]

  # ---- convert some columns to numeric ----
  numeric_cols <- c("Split", "Quality", "Cover")
  td[, (numeric_cols) := lapply(.SD, as.numeric), .SDcols = numeric_cols]
  if (use_int_id) {
    int_cols <- c("Yes", "No", "Missing")
    td[, (int_cols) := lapply(.SD, as.integer), .SDcols = int_cols]
  }

  # Drop helper columns used only during parsing.
  td[, t := NULL]
  td[, isLeaf := NULL]

  td[order(Tree, Node)]
}

# Avoid "no visible binding" notes during CRAN check: these names are
# data.table column names used with non-standard evaluation, so they are
# never declared as ordinary variables.
globalVariables(c("Tree", "Node", "ID", "Feature", "t", "isLeaf", ".SD", ".SDcols"))