diff --git a/Cargo.toml b/Cargo.toml index 8b545fb6e..11240d9db 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -31,6 +31,7 @@ intel-mkl-static = ["blas", "ndarray-linalg", "intel-mkl-src/mkl-static-lp64-seq intel-mkl-system = ["blas", "ndarray-linalg", "intel-mkl-src/mkl-dynamic-lp64-seq"] blas = ["ndarray/blas"] +polars = ["polars-core"] [dependencies] num-traits = "0.2" @@ -40,7 +41,7 @@ ndarray = { version = "0.13", default-features = false, features = ["approx"] } ndarray-linalg = { version = "0.12.1", optional = true } [dependencies.intel-mkl-src] -version = "0.6.0" +path = "/home/lorenz/Downloads/intel-mkl-src/intel-mkl-src" default-features = false optional = true @@ -56,6 +57,11 @@ optional = true default-features = false features = ["cblas"] +[dependencies.polars-core] +version = "0.12" +optional = true +default-features = false + [dev-dependencies] ndarray-rand = "0.11" approx = { version = "0.3", default-features = false, features = ["std"] } diff --git a/datasets/Cargo.toml b/datasets/Cargo.toml index c2d87cd1a..eeddb805d 100644 --- a/datasets/Cargo.toml +++ b/datasets/Cargo.toml @@ -13,6 +13,8 @@ ndarray = { version = "0.13", default-features = false } ndarray-csv = "0.4" csv = "1.1" flate2 = "1.0" +polars-core = { version = "0.12", default-features = false, optional = true } +polars-io = { version = "0.12", default-features = false, optional = true } [dev-dependencies] approx = { version = "0.3", default-features = false, features = ["std"] } @@ -23,3 +25,4 @@ diabetes = [] iris = [] winequality = [] linnerud = [] +income = ["polars-core", "polars-io", "linfa/polars"] diff --git a/datasets/data/income-test.tar.gz b/datasets/data/income-test.tar.gz new file mode 100644 index 000000000..920c233e0 Binary files /dev/null and b/datasets/data/income-test.tar.gz differ diff --git a/datasets/data/income-train.tar.gz b/datasets/data/income-train.tar.gz new file mode 100644 index 000000000..42aa424e4 Binary files /dev/null and b/datasets/data/income-train.tar.gz differ diff --git a/datasets/src/lib.rs b/datasets/src/lib.rs index 8b9b64551..02235923a 100644 --- a/datasets/src/lib.rs +++ b/datasets/src/lib.rs @@ -39,6 +39,15 @@ use linfa::Dataset; use ndarray::prelude::*; use ndarray_csv::Array2Reader; +#[cfg(feature = "income")] +use linfa::dataset::Dataframe; +#[cfg(feature = "income")] +use std::io::{Cursor, Read}; +#[cfg(feature = "income")] +use polars_core::frame::DataFrame; +#[cfg(feature = "income")] +use polars_io::{csv::CsvReader, SerReader}; + #[cfg(any( feature = "iris", feature = "diabetes", @@ -58,6 +67,20 @@ fn array_from_buf(buf: &[u8]) -> Array2 { reader.deserialize_array2_dynamic().unwrap() } +#[cfg(feature = "income")] +fn dataframe_from_buf(buf: &[u8]) -> DataFrame { + let mut file = GzDecoder::new(buf); + let mut buf: Vec = Vec::new(); + file.read_to_end(&mut buf).unwrap(); + let buf = Cursor::new(buf); + + CsvReader::new(buf) + .infer_schema(None) + .has_header(false) + .finish() + .unwrap() +} + #[cfg(feature = "iris")] /// Read in the iris-flower dataset from dataset path. // The `.csv` data is two dimensional: Axis(0) denotes y-axis (rows), Axis(1) denotes x-axis (columns) @@ -157,6 +180,14 @@ pub fn linnerud() -> Dataset { Dataset::new(input_array, output_array).with_feature_names(feature_names) } +#[cfg(feature = "income")] +pub fn income() -> (Dataframe, Dataframe) { + let input_data = include_bytes!("../data/income-train.tar.gz"); + let input_dataframe = dataframe_from_buf(&input_data[..]); + + panic!("") +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/dataset/impl_records.rs b/src/dataset/impl_records.rs index 49e9d62dd..97d74c5de 100644 --- a/src/dataset/impl_records.rs +++ b/src/dataset/impl_records.rs @@ -1,6 +1,9 @@ use super::{DatasetBase, Float, Records}; use ndarray::{ArrayBase, Axis, Data, Dimension}; +#[cfg(feature = "polars")] +use polars_core::frame::DataFrame; + /// Implement records for NdArrays impl, I: Dimension> Records for ArrayBase { type Elem = F; @@ -52,3 +55,16 @@ impl Records for &R { (*self).nfeatures() } } + +#[cfg(feature = "polars")] +impl Records for DataFrame { + type Elem = (); + + fn nsamples(&self) -> usize { + self.shape().0 + } + + fn nfeatures(&self) -> usize { + self.shape().1 + } +} diff --git a/src/dataset/mod.rs b/src/dataset/mod.rs index b4096c6fd..446dd0321 100644 --- a/src/dataset/mod.rs +++ b/src/dataset/mod.rs @@ -8,6 +8,9 @@ use ndarray::{ }; use num_traits::{FromPrimitive, NumAssignOps, Signed}; +#[cfg(feature = "polars")] +use polars_core::frame::DataFrame; + use std::cmp::{Ordering, PartialOrd}; use std::collections::{HashMap, HashSet}; use std::hash::Hash; @@ -147,6 +150,10 @@ pub type DatasetView<'a, D, T> = DatasetBase, ArrayView<'a pub type DatasetPr = DatasetBase, Ix2>, CountedTargets, Ix3>>>; +#[cfg(feature = "polars")] +/// Dataframe +pub type Dataframe = DatasetBase, Ix2>>; + /// Record trait pub trait Records: Sized { type Elem; diff --git a/src/prelude.rs b/src/prelude.rs index 030686626..f1769c91f 100644 --- a/src/prelude.rs +++ b/src/prelude.rs @@ -13,6 +13,9 @@ pub use crate::traits::*; #[doc(no_inline)] pub use crate::dataset::{AsTargets, Dataset, DatasetBase, DatasetView, Float, Records}; +#[cfg(feature = "polars")] +pub use crate::dataset::Dataframe; + #[doc(no_inline)] pub use crate::metrics_classification::{BinaryClassification, ConfusionMatrix, ToConfusionMatrix};