"""
Demo for survival analysis (regression) using Accelerated Failure Time (AFT) model,
using Optuna to tune hyperparameters.

The script:
1. loads the Veterans' Administration Lung Cancer data set;
2. splits it into training and validation sets;
3. searches the AFT hyperparameter space with Optuna, minimizing the validation
   ``aft-nloglik``;
4. retrains with the best combination;
5. prints validation-set predictions;
6. saves the model to ``aft_best_model.json``.
"""
from sklearn.model_selection import ShuffleSplit
import pandas as pd
import numpy as np
import xgboost as xgb
import optuna

# Training / search settings.
NUM_BOOST_ROUND = 10000      # upper bound on boosting rounds; early stopping ends sooner
EARLY_STOPPING_ROUNDS = 50   # rounds without validation improvement before stopping
MIN_BEST_ITERATION = 25      # trials whose best_iteration is below this are rejected
N_TRIALS = 200               # number of Optuna trials

# The Veterans' Administration Lung Cancer Trial
# The Statistical Analysis of Failure Time Data by Kalbfleisch J. and Prentice R (1980)
df = pd.read_csv('../data/veterans_lung_cancer.csv')
print('Training data:')
print(df)

# Split features and labels. The AFT objective takes interval labels: each row has a
# lower and an upper bound on the survival time.
y_lower_bound = df['Survival_label_lower_bound']
y_upper_bound = df['Survival_label_upper_bound']
X = df.drop(['Survival_label_lower_bound', 'Survival_label_upper_bound'], axis=1)

# Split data into training and validation sets (only the first ShuffleSplit split is used).
# NOTE(review): test_size=.7 puts 70% of the rows in the validation set and only 30%
# in training -- confirm this ratio is intended.
rs = ShuffleSplit(n_splits=2, test_size=.7, random_state=0)
train_index, valid_index = next(rs.split(X))


def _make_dmatrix(index):
    """Build a DMatrix for the given row positions, with AFT lower/upper label bounds."""
    dmat = xgb.DMatrix(X.values[index, :])
    # ShuffleSplit yields row positions, so select labels positionally with .iloc.
    dmat.set_float_info('label_lower_bound', y_lower_bound.iloc[index])
    dmat.set_float_info('label_upper_bound', y_upper_bound.iloc[index])
    return dmat


dtrain = _make_dmatrix(train_index)
dvalid = _make_dmatrix(valid_index)

# Hyperparameters common to all trials
base_params = {'verbosity': 0,
               'objective': 'survival:aft',
               'eval_metric': 'aft-nloglik',
               'tree_method': 'hist'}


def objective(trial):
    """Optuna objective: train one AFT model and return its best validation aft-nloglik.

    Samples the following from the search space, then merges in ``base_params``:
    - learning rate;
    - AFT loss distribution and its scale;
    - tree depth;
    - L2 (``lambda``) and L1 (``alpha``) regularization.

    Training uses early stopping on ``dvalid`` and passes Optuna's XGBoost pruning
    callback, which monitors ``'valid-aft-nloglik'``.

    Returns:
        ``bst.best_score`` for accepted models. Returns ``np.inf`` (the worst value
        for the minimizing study) when ``bst.best_iteration`` is below
        ``MIN_BEST_ITERATION``, which rejects models that stopped improving after
        very few trees.
    """
    # Search space. suggest_float(..., log=True) replaces the deprecated
    # suggest_loguniform and samples the same log-uniform ranges.
    params = {
        'learning_rate': trial.suggest_float('learning_rate', 0.01, 1.0, log=True),
        'aft_loss_distribution': trial.suggest_categorical(
            'aft_loss_distribution', ['normal', 'logistic', 'extreme']),
        'aft_loss_distribution_scale': trial.suggest_float(
            'aft_loss_distribution_scale', 0.1, 10.0, log=True),
        'max_depth': trial.suggest_int('max_depth', 3, 8),
        'lambda': trial.suggest_float('lambda', 1e-8, 1.0, log=True),
        'alpha': trial.suggest_float('alpha', 1e-8, 1.0, log=True),
    }
    params.update(base_params)
    # The monitored key is the eval-set name 'valid' joined with the metric name.
    pruning_callback = optuna.integration.XGBoostPruningCallback(trial,
                                                                 'valid-aft-nloglik')
    bst = xgb.train(params, dtrain, num_boost_round=NUM_BOOST_ROUND,
                    evals=[(dtrain, 'train'), (dvalid, 'valid')],
                    early_stopping_rounds=EARLY_STOPPING_ROUNDS,
                    verbose_eval=False, callbacks=[pruning_callback])
    if bst.best_iteration < MIN_BEST_ITERATION:
        return np.inf  # Reject models with < 25 trees
    return bst.best_score


# Run hyperparameter search
study = optuna.create_study(direction='minimize')
study.optimize(objective, n_trials=N_TRIALS)
print('Completed hyperparameter tuning with best aft-nloglik = {}.'.format(study.best_trial.value))
params = {}
params.update(base_params)
params.update(study.best_trial.params)

# Re-run training with the best hyperparameter combination
print('Re-running the best trial... params = {}'.format(params))
bst = xgb.train(params, dtrain, num_boost_round=NUM_BOOST_ROUND,
                evals=[(dtrain, 'train'), (dvalid, 'valid')],
                early_stopping_rounds=EARLY_STOPPING_ROUNDS)

# Run prediction on the validation set. xgb.train returns the model from the last
# iteration, including the rounds trained after the best one, so prediction is
# limited to trees [0, best_iteration].
best_range = (0, bst.best_iteration + 1)
df = pd.DataFrame({'Label (lower bound)': y_lower_bound.iloc[valid_index],
                   'Label (upper bound)': y_upper_bound.iloc[valid_index],
                   'Predicted label': bst.predict(dvalid, iteration_range=best_range)})
print(df)
# Show only data points with right-censored labels (infinite upper bound)
print(df[np.isinf(df['Label (upper bound)'])])

# Save trained model.
# NOTE(review): this saves the full booster, including the rounds after
# best_iteration. Consider saving bst[:bst.best_iteration + 1] instead if
# consumers should get only the best model.
bst.save_model('aft_best_model.json')