//! LightGBM Dataset used for training use lightgbm3_sys::{DatasetHandle, C_API_DTYPE_FLOAT32, C_API_DTYPE_FLOAT64}; use std::os::raw::c_void; use std::{self, ffi::CString}; #[cfg(feature = "polars")] use polars::{datatypes::DataType::Float32, prelude::*}; use crate::{Error, Result}; // a way of implementing sealed traits until they // come to rust lang. more details at: // https://internals.rust-lang.org/t/sealed-traits/16797 mod private { pub trait Sealed {} impl Sealed for f32 {} impl Sealed for f64 {} } /// LightGBM dtype /// /// This trait is sealed as it is not intended /// to be implemented out of this crate pub trait DType: private::Sealed { fn get_c_api_dtype() -> i32; } impl DType for f32 { fn get_c_api_dtype() -> i32 { C_API_DTYPE_FLOAT32 as i32 } } impl DType for f64 { fn get_c_api_dtype() -> i32 { C_API_DTYPE_FLOAT64 as i32 } } /// LightGBM Dataset pub struct Dataset { pub(crate) handle: DatasetHandle, } impl Dataset { /// Creates a new Dataset object from the LightGBM's DatasetHandle. fn new(handle: DatasetHandle) -> Self { Self { handle } } /// Creates a new `Dataset` (x, labels) from flat `&[f64]` slice with a specified number /// of features (columns). /// /// `row_major` should be set to `true` for row-major order and `false` otherwise. 
///
    /// The number of rows is `flat_x.len() / n_features`, and `label` must hold
    /// exactly one value per row.
    ///
    /// # Example
    /// ```
    /// use lightgbm3::Dataset;
    ///
    /// let x = vec![vec![1.0, 0.1, 0.2],
    ///              vec![0.7, 0.4, 0.5],
    ///              vec![0.9, 0.8, 0.5],
    ///              vec![0.2, 0.2, 0.8],
    ///              vec![0.1, 0.7, 1.0]];
    /// let flat_x = x.into_iter().flatten().collect::<Vec<f64>>();
    /// let label = vec![0.0, 0.0, 0.0, 1.0, 1.0];
    /// let n_features = 3;
    /// let dataset = Dataset::from_slice(&flat_x, &label, n_features, true).unwrap();
    /// ```
    ///
    /// # Errors
    ///
    /// Returns an error in any of these cases:
    /// - `n_features <= 0`;
    /// - `flat_x.len()` is not a multiple of `n_features`;
    /// - the slice is empty;
    /// - the row count exceeds `i32::MAX`;
    /// - `label.len()` differs from the row count;
    /// - a LightGBM C API call fails.
    pub fn from_slice<T: DType>(
        flat_x: &[T],
        label: &[f32],
        n_features: i32,
        is_row_major: bool,
    ) -> Result<Self> {
        if n_features <= 0 {
            return Err(Error::new("number of features should be greater than 0"));
        }
        if flat_x.len() % n_features as usize != 0 {
            return Err(Error::new(
                "number of features doesn't correspond to slice size",
            ));
        }
        let n_rows = flat_x.len() / n_features as usize;
        if n_rows == 0 {
            return Err(Error::new("slice is empty"));
        } else if n_rows > i32::MAX as usize {
            return Err(Error::new(format!(
                "number of rows should be less than {}. Got {}",
                i32::MAX,
                n_rows
            )));
        }
        // LightGBM reads exactly `n_rows` labels through the raw pointer below,
        // so a shorter slice would be an out-of-bounds read.
        if label.len() != n_rows {
            return Err(Error::new(format!(
                "number of labels should be equal to number of rows ({}). Got {}",
                n_rows,
                label.len()
            )));
        }

        let params = CString::new("").unwrap();
        let label_str = CString::new("label").unwrap();
        let reference = std::ptr::null_mut(); // not used
        let mut dataset_handle = std::ptr::null_mut(); // will point to a new DatasetHandle

        lgbm_call!(lightgbm3_sys::LGBM_DatasetCreateFromMat(
            flat_x.as_ptr() as *const c_void,
            T::get_c_api_dtype(),
            n_rows as i32,
            n_features,
            i32::from(is_row_major), // 1 for row-major, 0 for column-major
            params.as_ptr(),
            reference,
            &mut dataset_handle
        ))?;

        let set_label = lgbm_call!(lightgbm3_sys::LGBM_DatasetSetField(
            dataset_handle,
            label_str.as_ptr(),
            label.as_ptr() as *const c_void,
            n_rows as i32,
            C_API_DTYPE_FLOAT32 as i32 // labels should be always float32
        ));
        if let Err(err) = set_label {
            // The handle is not owned by a `Dataset` yet, so release it here to
            // avoid leaking it. This is best-effort: the original error is reported.
            let _ = lgbm_call!(lightgbm3_sys::LGBM_DatasetFree(dataset_handle));
            return Err(err);
        }

        Ok(Self::new(dataset_handle))
    }

    /// Creates a new `Dataset` (x, labels) from a `Vec<Vec<T>>` of rows or columns.
///
    /// When `is_row_major` is `true`, each inner vector is one row and its length is
    /// the number of features. When it is `false`, each inner vector is one feature
    /// column. All inner vectors must have the same length, and `label` must hold
    /// exactly one value per row.
    ///
    /// # Example
    /// ```
    /// use lightgbm3::Dataset;
    ///
    /// let data = vec![vec![1.0, 0.1, 0.2],
    ///                 vec![0.7, 0.4, 0.5],
    ///                 vec![0.9, 0.8, 0.5],
    ///                 vec![0.2, 0.2, 0.8],
    ///                 vec![0.1, 0.7, 1.0]];
    /// let label = vec![0.0, 0.0, 0.0, 1.0, 1.0]; // should be Vec<f32>
    /// let dataset = Dataset::from_vec_of_vec(data, label, true).unwrap();
    /// ```
    ///
    /// # Errors
    ///
    /// Returns an error in any of these cases:
    /// - `x` or its first inner vector is empty;
    /// - the inner vectors have different lengths;
    /// - the feature count exceeds `i32::MAX`;
    /// - `label.len()` differs from the row count;
    /// - building the underlying dataset fails.
    pub fn from_vec_of_vec<T: DType>(
        x: Vec<Vec<T>>,
        label: Vec<f32>,
        is_row_major: bool,
    ) -> Result<Self> {
        if x.is_empty() || x[0].is_empty() {
            return Err(Error::new("x is empty"));
        }
        let inner_len = x[0].len();
        // Flattening ragged input would silently shift values into the wrong
        // row/column, so require a rectangular matrix up front.
        if x.iter().any(|inner| inner.len() != inner_len) {
            return Err(Error::new(
                "all inner vectors of x should have the same length",
            ));
        }
        let (n_rows, n_features) = if is_row_major {
            (x.len(), inner_len)
        } else {
            (inner_len, x.len())
        };
        if n_features > i32::MAX as usize {
            return Err(Error::new(format!(
                "number of features should be less than {}. Got {}",
                i32::MAX,
                n_features
            )));
        }
        // LightGBM reads exactly `n_rows` labels from the label buffer.
        if label.len() != n_rows {
            return Err(Error::new(format!(
                "number of labels should be equal to number of rows ({}). Got {}",
                n_rows,
                label.len()
            )));
        }
        let mut x_flat = Vec::with_capacity(n_rows * n_features);
        for inner in x {
            x_flat.extend(inner);
        }
        Self::from_slice(&x_flat, &label, n_features as i32, is_row_major)
    }

    /// Create a new `Dataset` from a tab-separated values file.
    ///
    /// file is `tsv`.
    /// ```text
    ///