require(xgboost)
require(jsonlite)

context("Models from previous versions of XGBoost can be loaded")

# Constants used below to compute the expected number of trees per model.
# NOTE(review): presumably these mirror the settings of the script that
# generated the stored models -- confirm against that script. Only kRounds,
# kForests and kClasses are read in this file.
metadata <- list(
  kRounds = 2,
  kRows = 1000,
  kCols = 4,
  kForests = 2,
  kMaxDepth = 2,
  kClasses = 3
)

# Check the configuration entries that are expected to be identical for every
# stored model: the feature count and the booster type.
#
# config: parsed JSON configuration (nested list) of a booster.
run_model_param_check <- function(config) {
  testthat::expect_equal(config$learner$learner_model_param$num_feature, '4')
  testthat::expect_equal(config$learner$learner_train_param$booster, 'gbtree')
}

# Count the trees in a booster by counting the 'booster[<n>]' header lines of
# its text dump.
#
# booster: object accepted by xgb.dump().
# Returns an integer count; 0 when the dump contains no tree header at all.
get_num_tree <- function(booster) {
  dump <- xgb.dump(booster)
  # One header line per tree, so counting matching lines counts trees.
  sum(grepl('booster\\[[0-9]+\\]', dump, perl = TRUE))
}

# Verify that a loaded booster has the tree count, objective and model
# parameters expected for the given model kind.
#
# booster: the loaded model.
# name:    model kind taken from the file name; one of 'cls', 'logitraw',
#          'logit', 'ltr' or 'reg'. Any other value fails the 'reg' expectation.
run_booster_check <- function(booster, name) {
  # If given a handle, we need to call xgb.Booster.complete() prior to using xgb.config().
  if (inherits(booster, "xgb.Booster") && xgboost:::is.null.handle(booster$handle)) {
    booster <- xgb.Booster.complete(booster)
  }
  config <- jsonlite::fromJSON(xgb.config(booster))
  run_model_param_check(config)

  model_param <- config$learner$learner_model_param
  train_param <- config$learner$learner_train_param
  trees_per_class <- metadata$kForests * metadata$kRounds

  if (name == 'cls') {
    # The multi-class model is expected to hold one set of trees per class.
    testthat::expect_equal(get_num_tree(booster), trees_per_class * metadata$kClasses)
    testthat::expect_equal(as.numeric(model_param$base_score), 0.5)
    testthat::expect_equal(train_param$objective, 'multi:softmax')
    testthat::expect_equal(as.numeric(model_param$num_class), metadata$kClasses)
  } else if (name == 'logitraw') {
    testthat::expect_equal(get_num_tree(booster), trees_per_class)
    testthat::expect_equal(as.numeric(model_param$num_class), 0)
    testthat::expect_equal(train_param$objective, 'binary:logitraw')
  } else if (name == 'logit') {
    testthat::expect_equal(get_num_tree(booster), trees_per_class)
    testthat::expect_equal(as.numeric(model_param$num_class), 0)
    testthat::expect_equal(train_param$objective, 'binary:logistic')
  } else if (name == 'ltr') {
    testthat::expect_equal(get_num_tree(booster), trees_per_class)
    testthat::expect_equal(train_param$objective, 'rank:ndcg')
  } else {
    # Anything that is not one of the kinds above must be the regression model.
    testthat::expect_equal(name, 'reg')
    testthat::expect_equal(get_num_tree(booster), trees_per_class)
    testthat::expect_equal(as.numeric(model_param$base_score), 0.5)
    testthat::expect_equal(train_param$objective, 'reg:squarederror')
  }
}

test_that("Models from previous versions of XGBoost can be loaded", {
  bucket <- 'xgboost-ci-jenkins-artifacts'
  region <- 'us-west-2'
  file_name <- 'xgboost_r_model_compatibility_test.zip'

  # Download and extract into a private scratch directory instead of the
  # working directory, and remove it again when the test finishes.
  work_dir <- tempfile(pattern = 'xgb_model_compat_')
  dir.create(work_dir)
  on.exit(unlink(work_dir, recursive = TRUE), add = TRUE)
  zipfile <- file.path(work_dir, file_name)
  model_dir <- file.path(work_dir, 'models')

  download.file(paste0('https://', bucket, '.s3-', region, '.amazonaws.com/', file_name),
                destfile = zipfile, mode = 'wb', quiet = TRUE)
  unzip(zipfile, exdir = work_dir, overwrite = TRUE)

  pred_data <- xgb.DMatrix(matrix(c(0, 0, 0, 0), nrow = 1, ncol = 4))

  model_files <- list.files(model_dir)
  # Guard against a vacuous pass when the archive yields no model files.
  expect_true(length(model_files) > 0)

  for (x in model_files) {
    model_file <- file.path(model_dir, x)
    # Parse '<version>' and '<kind>' out of 'xgboost-<version>.<kind>.<ext>'.
    # Match on the bare file name so the directory part cannot interfere.
    m <- regexec("xgboost-([0-9\\.]+)\\.([a-z]+)\\.[a-z]+", x, perl = TRUE)
    m <- regmatches(x, m)[[1]]
    model_xgb_ver <- m[2]
    name <- m[3]
    is_rds <- endsWith(model_file, '.rds')

    # capture.output() collects what the load/predict calls print, so the
    # C++ warnings can be inspected below.
    cpp_warning <- capture.output({
      # Expect an R warning when a model is loaded from RDS and it was generated by version < 1.1.x
      if (is_rds && compareVersion(model_xgb_ver, '1.1.1.1') < 0) {
        booster <- readRDS(model_file)
        expect_warning(predict(booster, newdata = pred_data))
        # Reload so the second check starts from a freshly deserialized object.
        booster <- readRDS(model_file)
        expect_warning(run_booster_check(booster, name))
      } else {
        if (is_rds) {
          booster <- readRDS(model_file)
        } else {
          booster <- xgb.load(model_file)
        }
        predict(booster, newdata = pred_data)
        run_booster_check(booster, name)
      }
    })

    if (compareVersion(model_xgb_ver, '1.0.0.0') < 0) {
      # Expect a C++ warning when a model was generated in version < 1.0.x
      m <- grepl(paste0('.*Loading model from XGBoost < 1\\.0\\.0, consider saving it again for ',
                        'improved compatibility.*'), cpp_warning, perl = TRUE)
      expect_true(length(m) > 0 && all(m))
    } else if (is_rds && model_xgb_ver == '1.1.1.1') {
      # Expect a C++ warning when a model is loaded from RDS and it was generated by version 1.1.x
      m <- grepl(paste0('.*Attempted to load internal configuration for a model file that was ',
                        'generated by a previous version of XGBoost.*'), cpp_warning, perl = TRUE)
      expect_true(length(m) > 0 && all(m))
    }
  }
})