use float_cmp::ApproxEq;
use ndarray::arr2;
use pretty_assertions::{assert_eq, assert_ne};

use pyrus_nn::activations;

#[test]
fn test_tanh() {
    let x = arr2(&[[3.5]]);
    let result = activations::tanh(&x, false);
    assert_eq!(result.sum(), 0.99817795);
}

#[test]
fn test_sigmoid() {
    let x = arr2(&[[3.5]]);
    let result = activations::sigmoid(&x, false);
    assert_eq!(result.sum(), 0.97068775);
}

#[test]
fn test_softmax() {
    let x = arr2(&[[1., 2., 3.]]);
    let result = activations::softmax(&x, false);
    assert_eq!(result, arr2(&[[0.09003057, 0.24472848, 0.66524094]]));
}