{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "a22a5ebb-54fe-431f-bc8c-667d36f6f798", "metadata": { "execution": { "iopub.execute_input": "2023-03-06T20:47:04.654741Z", "iopub.status.busy": "2023-03-06T20:47:04.654422Z", "iopub.status.idle": "2023-03-06T20:47:04.659468Z", "shell.execute_reply": "2023-03-06T20:47:04.657869Z", "shell.execute_reply.started": "2023-03-06T20:47:04.654711Z" }, "tags": [] }, "outputs": [], "source": [ "import pandas as pd\n", "from sklearn.metrics import r2_score, accuracy_score\n", "from sklearn.tree import DecisionTreeRegressor, DecisionTreeClassifier\n", "from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier\n", "import rustrees.decision_tree as rt_dt\n", "import rustrees.random_forest as rt_rf\n", "import time\n", "import numpy as np" ] }, { "cell_type": "code", "execution_count": 2, "id": "7339d8b2-1b14-445c-8bf3-6ed0c050437c", "metadata": { "execution": { "iopub.execute_input": "2023-03-07T16:36:23.400487Z", "iopub.status.busy": "2023-03-07T16:36:23.400172Z", "iopub.status.idle": "2023-03-07T16:36:23.407520Z", "shell.execute_reply": "2023-03-07T16:36:23.406510Z", "shell.execute_reply.started": "2023-03-07T16:36:23.400459Z" }, "tags": [] }, "outputs": [], "source": [ "datasets = {\n", " \"reg\": [\"diabetes\", \"housing\", \"dgp\"],\n", " \"clf\": [\"breast_cancer\", \"titanic\"]\n", "}" ] }, { "cell_type": "code", "execution_count": 3, "id": "f8d469ba", "metadata": {}, "outputs": [], "source": [ "def evaluate_dataset(dataset, problem, model, max_depth, n_repeats, n_estimators=None):\n", " df_train = pd.read_csv(f\"../../datasets/{dataset}_train.csv\")\n", " df_test = pd.read_csv(f\"../../datasets/{dataset}_test.csv\")\n", "\n", " if problem == \"reg\":\n", " metric_fn = r2_score\n", " metric = \"r2\"\n", " if model == \"dt\":\n", " model_sk = DecisionTreeRegressor(max_depth=max_depth)\n", " model_rt = rt_dt.DecisionTreeRegressor(max_depth=max_depth)\n", " elif model == \"rf\":\n", " model_sk = RandomForestRegressor(n_estimators=n_estimators, max_depth=max_depth, n_jobs=-1)\n", " model_rt = rt_rf.RandomForestRegressor(n_estimators=n_estimators, max_depth=max_depth)\n", " elif problem == \"clf\":\n", " metric_fn = accuracy_score\n", " metric = \"acc\"\n", " if model == \"dt\":\n", " model_sk = DecisionTreeClassifier(max_depth=max_depth)\n", " model_rt = rt_dt.DecisionTreeClassifier(max_depth=max_depth)\n", " elif model == \"rf\":\n", " model_sk = RandomForestClassifier(n_estimators=n_estimators, max_depth=max_depth, n_jobs=-1)\n", " model_rt = rt_rf.RandomForestClassifier(n_estimators=n_estimators, max_depth=max_depth)\n", "\n", " start_time = time.time()\n", " results_sk = []\n", " for _ in range(n_repeats):\n", " model_sk.fit(df_train.drop(\"target\", axis=1), df_train.target)\n", " results_sk.append(metric_fn(df_test.target, model_sk.predict(df_test.drop(\"target\", axis=1))))\n", " sk_time = (time.time() - start_time)/n_repeats\n", " sk_mean = np.mean(results_sk)\n", " sk_std = np.std(results_sk)\n", " \n", " start_time = time.time()\n", " results_rt = []\n", " for _ in range(n_repeats):\n", " model_rt.fit(df_train.drop(\"target\", axis=1), df_train.target)\n", " results_rt.append(metric_fn(df_test.target, model_rt.predict(df_test.drop(\"target\", axis=1))))\n", " rt_time = (time.time() - start_time)/n_repeats\n", " rt_mean = np.mean(results_rt)\n", " rt_std = np.std(results_rt)\n", " \n", " return (dataset, sk_mean, rt_mean, sk_std, rt_std, sk_time, rt_time, metric)" ] }, { "cell_type": "code", "execution_count": 4, "id": "8a2ae87c-9213-4c02-bd19-1a844eff5f05", "metadata": { "execution": { "iopub.execute_input": "2023-03-07T16:36:24.510409Z", "iopub.status.busy": "2023-03-07T16:36:24.510122Z", "iopub.status.idle": "2023-03-07T16:36:24.610884Z", "shell.execute_reply": "2023-03-07T16:36:24.610170Z", "shell.execute_reply.started": "2023-03-07T16:36:24.510384Z" }, "tags": [] }, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
datasetsk_meanrt_meansk_stdrt_stdsk_time(s)rt_time(s)metric
0diabetes0.3153190.2700293.251468e-021.780794e-020.0026590.003520r2
1housing0.5997320.5983901.336886e-160.000000e+000.0429860.060472r2
2dgp0.9935090.9935104.440892e-160.000000e+000.0568520.360891r2
3breast_cancer0.9287020.9290186.747068e-036.746612e-030.0041650.006442acc
4titanic0.7864410.8067801.110223e-163.330669e-160.0023000.002896acc
\n", "
" ], "text/plain": [ " dataset sk_mean rt_mean sk_std rt_std sk_time(s) \\\n", "0 diabetes 0.315319 0.270029 3.251468e-02 1.780794e-02 0.002659 \n", "1 housing 0.599732 0.598390 1.336886e-16 0.000000e+00 0.042986 \n", "2 dgp 0.993509 0.993510 4.440892e-16 0.000000e+00 0.056852 \n", "3 breast_cancer 0.928702 0.929018 6.747068e-03 6.746612e-03 0.004165 \n", "4 titanic 0.786441 0.806780 1.110223e-16 3.330669e-16 0.002300 \n", "\n", " rt_time(s) metric \n", "0 0.003520 r2 \n", "1 0.060472 r2 \n", "2 0.360891 r2 \n", "3 0.006442 acc \n", "4 0.002896 acc " ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "results_reg = [evaluate_dataset(d, \"reg\", model=\"dt\", max_depth=5, n_repeats=100) for d in datasets[\"reg\"]]\n", "results_clf = [evaluate_dataset(d, \"clf\", model=\"dt\", max_depth=5, n_repeats=100) for d in datasets[\"clf\"]]\n", "results = results_reg + results_clf\n", "\n", "cols = \"dataset sk_mean rt_mean sk_std rt_std sk_time(s) rt_time(s) metric\".split()\n", "\n", "pd.DataFrame(results, columns=cols)\n" ] }, { "cell_type": "code", "execution_count": 5, "id": "75c713ea", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
datasetsk_meanrt_meansk_stdrt_stdsk_time(s)rt_time(s)metric
0diabetes0.4379380.4328590.0093380.0057730.1145100.010676r2
1housing0.4396450.4405550.0006130.0008570.2555930.401618r2
2dgp0.7563770.7560610.0003420.0002760.3227762.913919r2
3breast_cancer0.9466670.9371930.0034380.0036630.1265190.025618acc
4titanic0.7633900.7728810.0049820.0000000.1403000.011944acc
\n", "
" ], "text/plain": [ " dataset sk_mean rt_mean sk_std rt_std sk_time(s) \\\n", "0 diabetes 0.437938 0.432859 0.009338 0.005773 0.114510 \n", "1 housing 0.439645 0.440555 0.000613 0.000857 0.255593 \n", "2 dgp 0.756377 0.756061 0.000342 0.000276 0.322776 \n", "3 breast_cancer 0.946667 0.937193 0.003438 0.003663 0.126519 \n", "4 titanic 0.763390 0.772881 0.004982 0.000000 0.140300 \n", "\n", " rt_time(s) metric \n", "0 0.010676 r2 \n", "1 0.401618 r2 \n", "2 2.913919 r2 \n", "3 0.025618 acc \n", "4 0.011944 acc " ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "results_reg = [evaluate_dataset(d, \"reg\", model=\"rf\", max_depth=2, n_estimators=100, n_repeats=10) for d in datasets[\"reg\"]]\n", "results_clf = [evaluate_dataset(d, \"clf\", model=\"rf\", max_depth=2, n_estimators=100, n_repeats=10) for d in datasets[\"clf\"]]\n", "results = results_reg + results_clf\n", "\n", "cols = \"dataset sk_mean rt_mean sk_std rt_std sk_time(s) rt_time(s) metric\".split()\n", "\n", "pd.DataFrame(results, columns=cols)\n" ] }, { "cell_type": "code", "execution_count": null, "id": "b795007c", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.9" } }, "nbformat": 4, "nbformat_minor": 5 }