In [1]:
import pandas as pd
from sklearn.metrics import r2_score, accuracy_score
from sklearn.tree import DecisionTreeRegressor, DecisionTreeClassifier
from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
import rustrees.decision_tree as rt_dt
import rustrees.random_forest as rt_rf
import time
import numpy as np

In [2]:
# Benchmark dataset names grouped by task type; the keys are the `problem`
# codes accepted by evaluate_dataset ("reg" = regression, "clf" = classification).
datasets = dict(
    reg=["diabetes", "housing", "dgp"],
    clf=["breast_cancer", "titanic"],
)

In [3]:
def evaluate_dataset(dataset, problem, model, max_depth, n_repeats, n_estimators=None):
    df_train = pd.read_csv(f"../../datasets/{dataset}_train.csv")
    df_test = pd.read_csv(f"../../datasets/{dataset}_test.csv")

    if problem == "reg":
        metric_fn = r2_score
        metric = "r2"
        if model == "dt":
            model_sk = DecisionTreeRegressor(max_depth=max_depth)
            model_rt = rt_dt.DecisionTreeRegressor(max_depth=max_depth)
        elif model == "rf":
            model_sk = RandomForestRegressor(n_estimators=n_estimators, max_depth=max_depth, n_jobs=-1)
            model_rt = rt_rf.RandomForestRegressor(n_estimators=n_estimators, max_depth=max_depth)
    elif problem == "clf":
        metric_fn = accuracy_score
        metric = "acc"
        if model == "dt":
            model_sk = DecisionTreeClassifier(max_depth=max_depth)
            model_rt = rt_dt.DecisionTreeClassifier(max_depth=max_depth)
        elif model == "rf":
            model_sk = RandomForestClassifier(n_estimators=n_estimators, max_depth=max_depth, n_jobs=-1)
            model_rt = rt_rf.RandomForestClassifier(n_estimators=n_estimators, max_depth=max_depth)

    start_time = time.time()
    results_sk = []
    for _ in range(n_repeats):
        model_sk.fit(df_train.drop("target", axis=1), df_train.target)
        results_sk.append(metric_fn(df_test.target, model_sk.predict(df_test.drop("target", axis=1))))
    sk_time = (time.time() - start_time)/n_repeats
    sk_mean = np.mean(results_sk)
    sk_std = np.std(results_sk)
    
    start_time = time.time()
    results_rt = []
    for _ in range(n_repeats):
        model_rt.fit(df_train.drop("target", axis=1), df_train.target)
        results_rt.append(metric_fn(df_test.target, model_rt.predict(df_test.drop("target", axis=1))))
    rt_time = (time.time() - start_time)/n_repeats
    rt_mean = np.mean(results_rt)
    rt_std = np.std(results_rt)
        
    return (dataset, sk_mean, rt_mean, sk_std, rt_std, sk_time, rt_time, metric)

In [4]:
# Single decision trees (max_depth=5), 100 fit/predict repeats per dataset:
# one row per dataset comparing scikit-learn ("sk") with rustrees ("rt").
results_reg = [
    evaluate_dataset(name, "reg", model="dt", max_depth=5, n_repeats=100)
    for name in datasets["reg"]
]
results_clf = [
    evaluate_dataset(name, "clf", model="dt", max_depth=5, n_repeats=100)
    for name in datasets["clf"]
]
results = [*results_reg, *results_clf]

# Labels follow the order of the tuple returned by evaluate_dataset.
cols = ["dataset", "sk_mean", "rt_mean", "sk_std", "rt_std", "sk_time(s)", "rt_time(s)", "metric"]

pd.DataFrame(results, columns=cols)


Unnamed: 0,dataset,sk_mean,rt_mean,sk_std,rt_std,sk_time(s),rt_time(s),metric
0,diabetes,0.315319,0.270029,0.03251468,0.01780794,0.002659,0.00352,r2
1,housing,0.599732,0.59839,1.336886e-16,0.0,0.042986,0.060472,r2
2,dgp,0.993509,0.99351,4.440892e-16,0.0,0.056852,0.360891,r2
3,breast_cancer,0.928702,0.929018,0.006747068,0.006746612,0.004165,0.006442,acc
4,titanic,0.786441,0.80678,1.110223e-16,3.330669e-16,0.0023,0.002896,acc


In [5]:
# Random forests (100 trees, max_depth=2), 10 fit/predict repeats per dataset:
# one row per dataset comparing scikit-learn ("sk") with rustrees ("rt").
results_reg = [
    evaluate_dataset(name, "reg", model="rf", max_depth=2, n_estimators=100, n_repeats=10)
    for name in datasets["reg"]
]
results_clf = [
    evaluate_dataset(name, "clf", model="rf", max_depth=2, n_estimators=100, n_repeats=10)
    for name in datasets["clf"]
]
results = [*results_reg, *results_clf]

# Labels follow the order of the tuple returned by evaluate_dataset.
cols = ["dataset", "sk_mean", "rt_mean", "sk_std", "rt_std", "sk_time(s)", "rt_time(s)", "metric"]

pd.DataFrame(results, columns=cols)


Unnamed: 0,dataset,sk_mean,rt_mean,sk_std,rt_std,sk_time(s),rt_time(s),metric
0,diabetes,0.437938,0.432859,0.009338,0.005773,0.11451,0.010676,r2
1,housing,0.439645,0.440555,0.000613,0.000857,0.255593,0.401618,r2
2,dgp,0.756377,0.756061,0.000342,0.000276,0.322776,2.913919,r2
3,breast_cancer,0.946667,0.937193,0.003438,0.003663,0.126519,0.025618,acc
4,titanic,0.76339,0.772881,0.004982,0.0,0.1403,0.011944,acc
