Crates.io | eval-metrics |
lib.rs | eval-metrics |
version | 1.0.2 |
source | src |
created_at | 2021-05-11 23:29:28.761972 |
updated_at | 2022-06-06 16:39:17.657057 |
description | Evaluation metrics for machine learning |
homepage | https://github.com/benjarison/eval-metrics |
repository | https://github.com/benjarison/eval-metrics |
id | 396351 |
size | 92,301 |
Evaluation metrics for machine learning
The goal of this library is to provide an intuitive collection of functions for computing evaluation metrics commonly
encountered in machine learning. Metrics are separated into modules for either `classification` or `regression`, with
the classification module supporting both binary and multi-class tasks. The distinction between binary and multi-class
classification is made explicit because certain metrics differ subtly between the two cases (e.g. multi-class metrics
often require an averaging method). A metric can be undefined for a variety of numerical reasons, such as a zero
denominator, and in these cases `Result` types are used to make this fact apparent.
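To illustrate why a metric may be undefined, here is a minimal sketch using only the standard library. The `MetricError` type and the `precision` function are hypothetical stand-ins written for this example; the crate itself reports failures through its own `EvalError` type.

```rust
/// Hypothetical error type for this sketch (the crate uses its own `EvalError`).
#[derive(Debug, PartialEq)]
pub enum MetricError {
    Undefined(&'static str),
}

/// Precision = TP / (TP + FP).
/// The ratio is undefined when nothing was predicted positive, so that case
/// is returned as an error instead of a NaN.
pub fn precision(tp: usize, fp: usize) -> Result<f64, MetricError> {
    let predicted_positive = tp + fp;
    if predicted_positive == 0 {
        return Err(MetricError::Undefined("precision: no positive predictions"));
    }
    Ok(tp as f64 / predicted_positive as f64)
}
```

With this sketch, `precision(2, 2)` yields `Ok(0.5)`, while `precision(0, 0)` yields an `Err` that the caller must handle explicitly.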
Metric | Task | Description |
---|---|---|
Accuracy | Binary Classification | Binary Class Accuracy |
Precision | Binary Classification | Binary Class Precision |
Recall | Binary Classification | Binary Class Recall |
F-1 | Binary Classification | Harmonic Mean of Precision and Recall |
MCC | Binary Classification | Matthews Correlation Coefficient |
ROC Curve | Binary Classification | Receiver Operating Characteristic Curve |
AUC | Binary Classification | Area Under ROC Curve |
PR Curve | Binary Classification | Precision-Recall Curve |
AP | Binary Classification | Average Precision |
Accuracy | Multi-Class Classification | Multi-Class Accuracy |
Precision | Multi-Class Classification | Multi-Class Precision |
Recall | Multi-Class Classification | Multi-Class Recall |
F-1 | Multi-Class Classification | Multi-Class F1 |
Rk | Multi-Class Classification | K-Category Correlation Coefficient as described by Gorodkin (2004) |
M-AUC | Multi-Class Classification | Multi-Class AUC as described by Hand and Till (2001) |
RMSE | Regression | Root Mean Squared Error |
MSE | Regression | Mean Squared Error |
MAE | Regression | Mean Absolute Error |
R-Square | Regression | Coefficient of Determination |
Correlation | Regression | Linear Correlation Coefficient |
The `BinaryConfusionMatrix` struct provides functionality for computing common binary classification metrics.
use eval_metrics::error::EvalError;
use eval_metrics::classification::BinaryConfusionMatrix;

fn main() -> Result<(), EvalError> {

    // note: these scores could also be f32 values
    let scores = vec![0.5, 0.2, 0.7, 0.4, 0.1, 0.3, 0.8, 0.9];
    let labels = vec![false, false, true, false, true, false, false, true];
    let threshold = 0.5;

    // compute confusion matrix from scores and labels
    let matrix = BinaryConfusionMatrix::compute(&scores, &labels, threshold)?;

    // counts
    let tpc = matrix.tp_count;
    let fpc = matrix.fp_count;
    let tnc = matrix.tn_count;
    let fnc = matrix.fn_count;

    // metrics
    let acc = matrix.accuracy()?;
    let pre = matrix.precision()?;
    let rec = matrix.recall()?;
    let f1 = matrix.f1()?;
    let mcc = matrix.mcc()?;

    // print matrix to console
    println!("{}", matrix);
    Ok(())
}
                            o=========================o
                            |          Label          |
                            o=========================o
                            |  Positive  |  Negative  |
o==============o============o============|============o
|              |  Positive  |     2      |     2      |
|  Prediction  |============|------------|------------|
|              |  Negative  |     1      |     3      |
o==============o============o=========================o
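As a worked check of the matrix printed above (TP = 2, FP = 2, FN = 1, TN = 3), the following sketch applies the textbook formulas directly to the four counts. The `Counts` struct is a hypothetical stand-in written for this example, not the crate's type, and it deliberately omits the zero-denominator checks that a `Result`-returning API would perform.

```rust
/// Hypothetical container for the four confusion matrix counts.
#[derive(Clone, Copy)]
pub struct Counts {
    pub tp: f64,
    pub fp: f64,
    pub tn: f64,
    pub fneg: f64, // `fn` is a reserved word in Rust
}

impl Counts {
    /// (TP + TN) / total
    pub fn accuracy(&self) -> f64 {
        (self.tp + self.tn) / (self.tp + self.fp + self.tn + self.fneg)
    }

    /// TP / (TP + FP)
    pub fn precision(&self) -> f64 {
        self.tp / (self.tp + self.fp)
    }

    /// TP / (TP + FN)
    pub fn recall(&self) -> f64 {
        self.tp / (self.tp + self.fneg)
    }

    /// Harmonic mean of precision and recall
    pub fn f1(&self) -> f64 {
        let (p, r) = (self.precision(), self.recall());
        2.0 * p * r / (p + r)
    }

    /// Matthews correlation coefficient
    pub fn mcc(&self) -> f64 {
        let num = self.tp * self.tn - self.fp * self.fneg;
        let den = ((self.tp + self.fp)
            * (self.tp + self.fneg)
            * (self.tn + self.fp)
            * (self.tn + self.fneg))
            .sqrt();
        num / den
    }
}
```

For the counts above, these formulas give an accuracy of 5/8 = 0.625, a precision of 0.5, a recall of 2/3, an F-1 of 4/7 (about 0.571), and an MCC of 4/sqrt(240) (about 0.258).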
In addition to the metrics derived from the confusion matrix, ROC curves and PR curves can be computed, providing metrics such as AUC and AP.
use eval_metrics::error::EvalError;
use eval_metrics::classification::{RocCurve, RocPoint, PrCurve, PrPoint};

fn main() -> Result<(), EvalError> {

    // note: these scores could also be f32 values
    let scores = vec![0.5, 0.2, 0.7, 0.4, 0.1, 0.3, 0.8, 0.9];
    let labels = vec![false, false, true, false, true, false, false, true];

    // construct roc curve
    let roc = RocCurve::compute(&scores, &labels)?;
    // compute auc
    let auc = roc.auc();
    // inspect roc curve points
    roc.points.iter().for_each(|point| {
        let tpr = point.tp_rate;
        let fpr = point.fp_rate;
        let thresh = point.threshold;
    });

    // construct pr curve
    let pr = PrCurve::compute(&scores, &labels)?;
    // compute average precision
    let ap = pr.ap();
    // inspect pr curve points
    pr.points.iter().for_each(|point| {
        let pre = point.precision;
        let rec = point.recall;
        let thresh = point.threshold;
    });
    Ok(())
}
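For intuition about what AUC measures, the sketch below computes it from its probabilistic definition: the chance that a randomly chosen positive example receives a higher score than a randomly chosen negative one, with ties counted as one half. Mathematically this equals the trapezoidal area under the ROC curve. The `pairwise_auc` function is a hypothetical helper written for this example; it is not how the crate computes `auc()`, and it runs in quadratic time.

```rust
/// AUC as P(score of a positive > score of a negative), ties counted as 0.5.
/// Returns `None` when either class is absent, since the value is then undefined.
pub fn pairwise_auc(scores: &[f64], labels: &[bool]) -> Option<f64> {
    // Split the scores by their true label.
    let positives: Vec<f64> = scores
        .iter()
        .zip(labels)
        .filter(|(_, l)| **l)
        .map(|(s, _)| *s)
        .collect();
    let negatives: Vec<f64> = scores
        .iter()
        .zip(labels)
        .filter(|(_, l)| !**l)
        .map(|(s, _)| *s)
        .collect();
    if positives.is_empty() || negatives.is_empty() {
        return None;
    }
    // Count the positive/negative pairs that are ranked correctly.
    let mut wins = 0.0;
    for &p in &positives {
        for &n in &negatives {
            if p > n {
                wins += 1.0;
            } else if p == n {
                wins += 0.5;
            }
        }
    }
    Some(wins / (positives.len() * negatives.len()) as f64)
}
```

Applied to the scores and labels in the example above, 9 of the 15 positive/negative pairs are ranked correctly, so the function returns `Some(0.6)` up to floating-point rounding.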
The `MultiConfusionMatrix` struct provides functionality for computing common multi-class classification metrics.
Additionally, averaging methods must be explicitly provided for several of these metrics.
use eval_metrics::error::EvalError;
use eval_metrics::classification::{MultiConfusionMatrix, Averaging};

fn main() -> Result<(), EvalError> {

    // note: these scores could also be f32 values
    let scores = vec![
        vec![0.3, 0.1, 0.6],
        vec![0.5, 0.2, 0.3],
        vec![0.2, 0.7, 0.1],
        vec![0.3, 0.3, 0.4],
        vec![0.5, 0.1, 0.4],
        vec![0.8, 0.1, 0.1],
        vec![0.3, 0.5, 0.2]
    ];
    let labels = vec![2, 1, 1, 2, 0, 2, 0];

    // compute confusion matrix from scores and labels
    let matrix = MultiConfusionMatrix::compute(&scores, &labels)?;

    // get counts
    let counts = &matrix.counts;

    // metrics
    let acc = matrix.accuracy()?;
    let mac_pre = matrix.precision(&Averaging::Macro)?;
    let wgt_pre = matrix.precision(&Averaging::Weighted)?;
    let mac_rec = matrix.recall(&Averaging::Macro)?;
    let wgt_rec = matrix.recall(&Averaging::Weighted)?;
    let mac_f1 = matrix.f1(&Averaging::Macro)?;
    let wgt_f1 = matrix.f1(&Averaging::Weighted)?;
    let rk = matrix.rk()?;

    // print matrix to console
    println!("{}", matrix);
    Ok(())
}
                           o===================================o
                           |               Label               |
                           o===================================o
                           |  Class-1  |  Class-2  |  Class-3  |
o==============o===========o===========|===========|===========o
|              |  Class-1  |     1     |     1     |     1     |
|              |===========|-----------|-----------|-----------|
|  Prediction  |  Class-2  |     1     |     1     |     0     |
|              |===========|-----------|-----------|-----------|
|              |  Class-3  |     0     |     0     |     2     |
o==============o===========o===================================o
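To make the difference between the two averaging methods concrete, here is a small sketch using only the standard library. It assumes the common convention that weighted averaging weights each class by its number of true labels (its support); whether the crate's `Averaging::Weighted` follows exactly this convention is an assumption here. Both function names are hypothetical.

```rust
/// Macro average: every class contributes equally.
pub fn macro_average(per_class: &[f64]) -> f64 {
    per_class.iter().sum::<f64>() / per_class.len() as f64
}

/// Weighted average: each class contributes in proportion to its support,
/// assumed here to be the number of true labels belonging to that class.
pub fn weighted_average(per_class: &[f64], support: &[usize]) -> f64 {
    let total: usize = support.iter().sum();
    per_class
        .iter()
        .zip(support)
        .map(|(m, &n)| m * n as f64)
        .sum::<f64>()
        / total as f64
}
```

For the matrix above, the per-class precisions are 1/3, 1/2 and 1, and the label counts per class are 2, 2 and 3. The macro average is therefore (1/3 + 1/2 + 1) / 3, about 0.611, while the support-weighted average is (2/3 + 1 + 3) / 7 = 2/3.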
In addition to these global metrics, per-class metrics can be obtained as well.
use eval_metrics::error::EvalError;
use eval_metrics::classification::MultiConfusionMatrix;

fn main() -> Result<(), EvalError> {

    // note: these scores could also be f32 values
    let scores = vec![
        vec![0.3, 0.1, 0.6],
        vec![0.5, 0.2, 0.3],
        vec![0.2, 0.7, 0.1],
        vec![0.3, 0.3, 0.4],
        vec![0.5, 0.1, 0.4],
        vec![0.8, 0.1, 0.1],
        vec![0.3, 0.5, 0.2]
    ];
    let labels = vec![2, 1, 1, 2, 0, 2, 0];

    // compute confusion matrix from scores and labels
    let matrix = MultiConfusionMatrix::compute(&scores, &labels)?;

    // per-class metrics
    let pca = matrix.per_class_accuracy();
    let pcp = matrix.per_class_precision();
    let pcr = matrix.per_class_recall();
    let pcf = matrix.per_class_f1();
    let pcm = matrix.per_class_mcc();

    // print per-class metrics to console
    println!("{:?}", pca);
    println!("{:?}", pcp);
    println!("{:?}", pcr);
    println!("{:?}", pcf);
    println!("{:?}", pcm);
    Ok(())
}
[Ok(0.5714285714285714), Ok(0.7142857142857143), Ok(0.8571428571428571)]
[Ok(0.3333333333333333), Ok(0.5), Ok(1.0)]
[Ok(0.5), Ok(0.5), Ok(0.6666666666666666)]
[Ok(0.4), Ok(0.5), Ok(0.8)]
[Ok(0.09128709291752773), Ok(0.3), Ok(0.7302967433402215)]
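As the output above shows, each per-class metric is wrapped in its own `Result`, so one undefined class does not invalidate the others. The sketch below shows two standard-library ways to consume such a vector. Both helper names are hypothetical, and the error type is left generic so the example does not depend on the crate.

```rust
/// All-or-nothing: `Ok` with every value if all classes are defined,
/// otherwise the first error encountered.
pub fn all_defined<E>(per_class: Vec<Result<f64, E>>) -> Result<Vec<f64>, E> {
    per_class.into_iter().collect()
}

/// Lenient: keep only the classes whose metric is defined,
/// paired with their class index.
pub fn defined_only<E>(per_class: &[Result<f64, E>]) -> Vec<(usize, f64)> {
    per_class
        .iter()
        .enumerate()
        .filter_map(|(i, r)| r.as_ref().ok().map(|v| (i, *v)))
        .collect()
}
```

The first form suits code that needs every class to be valid before continuing; the second suits reporting, where undefined classes can simply be skipped.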
In addition to the metrics derived from the confusion matrix, the M-AUC (multi-class AUC) metric as described by Hand and Till (2001) is provided as a standalone function:
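// scores and labels as in the multi-class examples above
// (assumes m_auc has been imported, e.g. from the classification module)
let mauc = m_auc(&scores, &labels)?;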
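For readers unfamiliar with the Hand and Till measure, the following is a rough standard-library sketch of its definition: for every pair of classes (i, j), average the AUC of class i against class j (using the class-i scores) with the AUC of class j against class i (using the class-j scores), then take the mean over all pairs. All function names here are hypothetical, and the crate's `m_auc` may differ in details such as tie handling and error reporting.

```rust
/// P(positive score > negative score), with ties counted as 0.5.
fn pairwise_auc(pos: &[f64], neg: &[f64]) -> f64 {
    let mut wins = 0.0;
    for &p in pos {
        for &n in neg {
            if p > n {
                wins += 1.0;
            } else if p == n {
                wins += 0.5;
            }
        }
    }
    wins / (pos.len() * neg.len()) as f64
}

/// A(i|j): how well the class-i score separates members of class i from members of class j.
fn class_vs_class_auc(scores: &[Vec<f64>], labels: &[usize], i: usize, j: usize) -> Option<f64> {
    // Collect the class-i score of every example whose true label is `class`.
    let column_i_for = |class: usize| -> Vec<f64> {
        scores
            .iter()
            .zip(labels)
            .filter(|(_, label)| **label == class)
            .map(|(row, _)| row[i])
            .collect()
    };
    let pos = column_i_for(i);
    let neg = column_i_for(j);
    if pos.is_empty() || neg.is_empty() {
        return None; // undefined when a class has no examples
    }
    Some(pairwise_auc(&pos, &neg))
}

/// Mean over all unordered class pairs of [A(i|j) + A(j|i)] / 2.
pub fn hand_till_m_auc(scores: &[Vec<f64>], labels: &[usize], n_classes: usize) -> Option<f64> {
    if n_classes < 2 {
        return None;
    }
    let mut sum = 0.0;
    for i in 0..n_classes {
        for j in (i + 1)..n_classes {
            let a_ij = class_vs_class_auc(scores, labels, i, j)?;
            let a_ji = class_vs_class_auc(scores, labels, j, i)?;
            sum += 0.5 * (a_ij + a_ji);
        }
    }
    let pairs = (n_classes * (n_classes - 1)) as f64 / 2.0;
    Some(sum / pairs)
}
```

When every class is perfectly separated by its own score column, each pairwise term is 1 and the sketch returns `Some(1.0)`; if any class has no examples, it returns `None`.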
All regression metrics operate on a pair of scores and labels.
use eval_metrics::error::EvalError;
use eval_metrics::regression::*;

fn main() -> Result<(), EvalError> {

    // note: these could also be f32 values
    let scores = vec![0.4, 0.7, -1.2, 2.5, 0.3];
    let labels = vec![0.2, 1.1, -0.9, 1.3, -0.2];

    // root mean squared error
    let rmse = rmse(&scores, &labels)?;
    // mean squared error
    let mse = mse(&scores, &labels)?;
    // mean absolute error
    let mae = mae(&scores, &labels)?;
    // coefficient of determination
    let rsq = rsq(&scores, &labels)?;
    // pearson correlation coefficient
    let corr = corr(&scores, &labels)?;
    Ok(())
}
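For reference, the sketch below spells out the textbook formulas behind the first four of these metrics using only the standard library. The functions are stand-ins that share names with the crate's functions but are not its implementation: they skip input validation (empty or mismatched inputs) and return plain `f64` values, and the R-Square definition used, 1 - SS_res / SS_tot, is the standard one and only assumed to match the crate's.

```rust
/// Mean squared error: mean of (score - label)^2.
pub fn mse(scores: &[f64], labels: &[f64]) -> f64 {
    scores
        .iter()
        .zip(labels)
        .map(|(s, y)| (s - y).powi(2))
        .sum::<f64>()
        / scores.len() as f64
}

/// Root mean squared error: square root of the MSE.
pub fn rmse(scores: &[f64], labels: &[f64]) -> f64 {
    mse(scores, labels).sqrt()
}

/// Mean absolute error: mean of |score - label|.
pub fn mae(scores: &[f64], labels: &[f64]) -> f64 {
    scores
        .iter()
        .zip(labels)
        .map(|(s, y)| (s - y).abs())
        .sum::<f64>()
        / scores.len() as f64
}

/// Coefficient of determination: 1 - SS_res / SS_tot.
pub fn rsq(scores: &[f64], labels: &[f64]) -> f64 {
    let n = labels.len() as f64;
    let mean = labels.iter().sum::<f64>() / n;
    let ss_tot = labels.iter().map(|y| (y - mean).powi(2)).sum::<f64>();
    let ss_res = mse(scores, labels) * n;
    1.0 - ss_res / ss_tot
}
```

For the scores and labels in the example above, the errors are 0.2, -0.4, -0.3, 1.2 and 0.5, which gives an MSE of 0.396, an RMSE of about 0.629, an MAE of 0.52, and an R-Square of 1 - 1.98 / 3.34, about 0.407.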