""" Example of using callbacks with Dask ==================================== """ import numpy as np import xgboost as xgb from xgboost.dask import DaskDMatrix from dask.distributed import Client from dask.distributed import LocalCluster from dask_ml.datasets import make_regression from dask_ml.model_selection import train_test_split def probability_for_going_backward(epoch): return 0.999 / (1.0 + 0.05 * np.log(1.0 + epoch)) # All callback functions must inherit from TrainingCallback class CustomEarlyStopping(xgb.callback.TrainingCallback): """A custom early stopping class where early stopping is determined stochastically. In the beginning, allow the metric to become worse with a probability of 0.999. As boosting progresses, the probability should be adjusted downward""" def __init__(self, *, validation_set, target_metric, maximize, seed): self.validation_set = validation_set self.target_metric = target_metric self.maximize = maximize self.seed = seed self.rng = np.random.default_rng(seed=seed) if maximize: self.better = lambda x, y: x > y else: self.better = lambda x, y: x < y def after_iteration(self, model, epoch, evals_log): metric_history = evals_log[self.validation_set][self.target_metric] if len(metric_history) < 2 or self.better( metric_history[-1], metric_history[-2] ): return False # continue training p = probability_for_going_backward(epoch) go_backward = self.rng.choice(2, size=(1,), replace=True, p=[1 - p, p]).astype( np.bool )[0] print( "The validation metric went into the wrong direction. " + f"Stopping training with probability {1 - p}..." ) if go_backward: return False # continue training else: return True # stop training def main(client): m = 100000 n = 100 X, y = make_regression(n_samples=m, n_features=n, chunks=200, random_state=0) X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0) dtrain = DaskDMatrix(client, X_train, y_train) dtest = DaskDMatrix(client, X_test, y_test) output = xgb.dask.train( client, { "verbosity": 1, "tree_method": "hist", "objective": "reg:squarederror", "eval_metric": "rmse", "max_depth": 6, "learning_rate": 1.0, }, dtrain, num_boost_round=1000, evals=[(dtrain, "train"), (dtest, "test")], callbacks=[ CustomEarlyStopping( validation_set="test", target_metric="rmse", maximize=False, seed=0 ) ], ) if __name__ == "__main__": # or use other clusters for scaling with LocalCluster(n_workers=4, threads_per_worker=1) as cluster: with Client(cluster) as client: main(client)