Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
220 changes: 110 additions & 110 deletions Cargo.lock

Large diffs are not rendered by default.

82 changes: 41 additions & 41 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ exclude = [
edition = "2024"
license = "MIT OR Apache-2.0"
readme = "README.md"
version = "0.21.0-pre.5"
version = "0.21.0"

[workspace.lints.clippy]

Expand Down Expand Up @@ -185,54 +185,54 @@ realfft = "3"
### Internal burn crates ###
# Declared here so each consumer Cargo.toml can use `workspace = true` instead of
# repeating the path and version.
burn = { path = "crates/burn", version = "=0.21.0-pre.5", default-features = false }
burn-autodiff = { path = "crates/burn-autodiff", version = "=0.21.0-pre.5", default-features = false }
burn-backend = { path = "crates/burn-backend", version = "=0.21.0-pre.5", default-features = false }
burn-candle = { path = "crates/burn-candle", version = "=0.21.0-pre.5", default-features = false }
burn-collective = { path = "crates/burn-collective", version = "=0.21.0-pre.5", default-features = false }
burn-communication = { path = "crates/burn-communication", version = "=0.21.0-pre.5", default-features = false }
burn-core = { path = "crates/burn-core", version = "=0.21.0-pre.5", default-features = false }
burn-cpu = { path = "crates/burn-cpu", version = "=0.21.0-pre.5", default-features = false }
burn-cubecl = { path = "crates/burn-cubecl", version = "=0.21.0-pre.5", default-features = false }
burn-cubecl-fusion = { path = "crates/burn-cubecl-fusion", version = "=0.21.0-pre.5", default-features = false }
burn-cuda = { path = "crates/burn-cuda", version = "=0.21.0-pre.5", default-features = false }
burn-dataset = { path = "crates/burn-dataset", version = "=0.21.0-pre.5", default-features = false }
burn-derive = { path = "crates/burn-derive", version = "=0.21.0-pre.5", default-features = false }
burn-dispatch = { path = "crates/burn-dispatch", version = "=0.21.0-pre.5", default-features = false }
burn-flex = { path = "crates/burn-flex", version = "=0.21.0-pre.5", default-features = false }
burn-fusion = { path = "crates/burn-fusion", version = "=0.21.0-pre.5", default-features = false }
burn-ir = { path = "crates/burn-ir", version = "=0.21.0-pre.5", default-features = false }
burn-ndarray = { path = "crates/burn-ndarray", version = "=0.21.0-pre.5", default-features = false }
burn-nn = { path = "crates/burn-nn", version = "=0.21.0-pre.5", default-features = false }
burn-optim = { path = "crates/burn-optim", version = "=0.21.0-pre.5", default-features = false }
burn-remote = { path = "crates/burn-remote", version = "=0.21.0-pre.5", default-features = false }
burn-rl = { path = "crates/burn-rl", version = "=0.21.0-pre.5", default-features = false }
burn-rocm = { path = "crates/burn-rocm", version = "=0.21.0-pre.5", default-features = false }
burn-router = { path = "crates/burn-router", version = "=0.21.0-pre.5", default-features = false }
burn-std = { path = "crates/burn-std", version = "=0.21.0-pre.5", default-features = false }
burn-store = { path = "crates/burn-store", version = "=0.21.0-pre.5", default-features = false }
burn-tch = { path = "crates/burn-tch", version = "=0.21.0-pre.5", default-features = false }
burn-tensor = { path = "crates/burn-tensor", version = "=0.21.0-pre.5", default-features = false }
burn-tensor-testgen = { path = "crates/burn-tensor-testgen", version = "=0.21.0-pre.5", default-features = false }
burn-train = { path = "crates/burn-train", version = "=0.21.0-pre.5", default-features = false }
burn-vision = { path = "crates/burn-vision", version = "=0.21.0-pre.5", default-features = false }
burn-wgpu = { path = "crates/burn-wgpu", version = "=0.21.0-pre.5", default-features = false }
burn = { path = "crates/burn", version = "0.21.0", default-features = false }
burn-autodiff = { path = "crates/burn-autodiff", version = "0.21.0", default-features = false }
burn-backend = { path = "crates/burn-backend", version = "0.21.0", default-features = false }
burn-candle = { path = "crates/burn-candle", version = "0.21.0", default-features = false }
burn-collective = { path = "crates/burn-collective", version = "0.21.0", default-features = false }
burn-communication = { path = "crates/burn-communication", version = "0.21.0", default-features = false }
burn-core = { path = "crates/burn-core", version = "0.21.0", default-features = false }
burn-cpu = { path = "crates/burn-cpu", version = "0.21.0", default-features = false }
burn-cubecl = { path = "crates/burn-cubecl", version = "0.21.0", default-features = false }
burn-cubecl-fusion = { path = "crates/burn-cubecl-fusion", version = "0.21.0", default-features = false }
burn-cuda = { path = "crates/burn-cuda", version = "0.21.0", default-features = false }
burn-dataset = { path = "crates/burn-dataset", version = "0.21.0", default-features = false }
burn-derive = { path = "crates/burn-derive", version = "0.21.0", default-features = false }
burn-dispatch = { path = "crates/burn-dispatch", version = "0.21.0", default-features = false }
burn-flex = { path = "crates/burn-flex", version = "0.21.0", default-features = false }
burn-fusion = { path = "crates/burn-fusion", version = "0.21.0", default-features = false }
burn-ir = { path = "crates/burn-ir", version = "0.21.0", default-features = false }
burn-ndarray = { path = "crates/burn-ndarray", version = "0.21.0", default-features = false }
burn-nn = { path = "crates/burn-nn", version = "0.21.0", default-features = false }
burn-optim = { path = "crates/burn-optim", version = "0.21.0", default-features = false }
burn-remote = { path = "crates/burn-remote", version = "0.21.0", default-features = false }
burn-rl = { path = "crates/burn-rl", version = "0.21.0", default-features = false }
burn-rocm = { path = "crates/burn-rocm", version = "0.21.0", default-features = false }
burn-router = { path = "crates/burn-router", version = "0.21.0", default-features = false }
burn-std = { path = "crates/burn-std", version = "0.21.0", default-features = false }
burn-store = { path = "crates/burn-store", version = "0.21.0", default-features = false }
burn-tch = { path = "crates/burn-tch", version = "0.21.0", default-features = false }
burn-tensor = { path = "crates/burn-tensor", version = "0.21.0", default-features = false }
burn-tensor-testgen = { path = "crates/burn-tensor-testgen", version = "0.21.0", default-features = false }
burn-train = { path = "crates/burn-train", version = "0.21.0", default-features = false }
burn-vision = { path = "crates/burn-vision", version = "0.21.0", default-features = false }
burn-wgpu = { path = "crates/burn-wgpu", version = "0.21.0", default-features = false }

### For the main burn branch. ###
# cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "f60142ccc35dcbede6db2c28a1315ae4cccabdd1" }
# cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "f60142ccc35dcbede6db2c28a1315ae4cccabdd1" }
# cubecl-zspace = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "f60142ccc35dcbede6db2c28a1315ae4cccabdd1" }
# cubek = { git = "https://github.com/tracel-ai/cubek", default-features = false, rev = "5e354fbaa20aa983698a2bfe4139b92f7d40023c" }
# cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "7fd665837516e84ce306e2c96dbbccab6728c159" }
# cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "7fd665837516e84ce306e2c96dbbccab6728c159" }
# cubecl-zspace = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "7fd665837516e84ce306e2c96dbbccab6728c159" }
# cubek = { git = "https://github.com/tracel-ai/cubek", default-features = false, rev = "007f35272262c85d269b13761c8be334b03061bf" }
### For local development. ###
# cubecl = { path = "../cubecl/crates/cubecl", default-features = false }
# cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false }
# cubecl-zspace = { path = "../cubecl/crates/cubecl-zspace", default-features = false }
# cubek = { path = "../cubek/crates/cubek", default-features = false }
### For the release. ###
cubecl = { version = "=0.10.0-pre.4", default-features = false }
cubecl-common = { version = "=0.10.0-pre.4", default-features = false }
cubecl-zspace = { version = "=0.10.0-pre.4", default-features = false }
cubek = { version = "=0.2.0-pre.5", default-features = false }
cubecl = { version = "0.10.0", default-features = false }
cubecl-common = { version = "0.10.0", default-features = false }
cubecl-zspace = { version = "0.10.0", default-features = false }
cubek = { version = "0.2.0", default-features = false }

[profile.dev]
debug = 1 # Speed up compilation time and not necessary.
43 changes: 0 additions & 43 deletions crates/burn-backend-tests/tests/cubecl/reduce.rs
Original file line number Diff line number Diff line change
Expand Up @@ -230,49 +230,6 @@ fn reduction_topk_3d_random_complex() {
actual.into_data().assert_eq(&expected.into_data(), false);
}

#[test]
fn test_topk_1d() {
    // Integer case: top-3 of an ascending 1-D tensor must come back descending.
    let int_input = TestTensorInt::<1>::from([1, 2, 3, 4, 5]);
    let int_top3 = int_input.topk(3, /*dim*/ 0);
    int_top3
        .into_data()
        .assert_eq(&TensorData::from([5, 4, 3]), false);

    // Float case: same data, compared approximately.
    let float_input = TestTensor::<1>::from([1., 2., 3., 4., 5.]);
    let float_top3 = float_input.topk(3, /*dim*/ 0);
    float_top3
        .into_data()
        .assert_approx_eq::<FloatElem>(&TensorData::from([5., 4., 3.]), Tolerance::default());
}

#[test]
fn test_topk() {
    // 3-D int tensor: top-2 along the innermost dimension (dim 2).
    let int_input = TestTensorInt::<3>::from([[[1, 4, 7], [2, 5, 6]], [[3, 0, 9], [8, 2, 8]]]);
    let int_top2 = int_input.topk(2, /*dim*/ 2);
    int_top2
        .into_data()
        .assert_eq(&TensorData::from([[[7, 4], [6, 5]], [[9, 3], [8, 8]]]), false);

    // 3-D float tensor: same layout, compared approximately.
    let float_input =
        TestTensor::<3>::from([[[1., 4., 7.], [2., 5., 6.]], [[3., 0., 9.], [8., 2., 8.]]]);
    let float_top2 = float_input.topk(2, /*dim*/ 2);
    float_top2.into_data().assert_approx_eq::<FloatElem>(
        &TensorData::from([[[7., 4.], [6., 5.]], [[9., 3.], [8., 8.]]]),
        Tolerance::default(),
    );
}

#[test]
fn reduction_argmin_should_match_reference_backend() {
let device = Default::default();
Expand Down
45 changes: 44 additions & 1 deletion crates/burn-backend-tests/tests/tensor/int/ops/topk.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use super::*;
use burn_tensor::TensorData;
use burn_tensor::{TensorData, Tolerance};

#[test]
fn test_topk_with_indices_1d() {
Expand All @@ -13,3 +13,46 @@ fn test_topk_with_indices_1d() {
let indices_expected = TensorData::from([4, 3, 2]);
indices.into_data().assert_eq(&indices_expected, false);
}

#[test]
fn test_topk_1d() {
    // Int: selecting the 3 largest of [1..5] yields them in descending order.
    let ints = TestTensorInt::<1>::from([1, 2, 3, 4, 5]);
    let got = ints.topk(3, /*dim*/ 0);
    got.into_data()
        .assert_eq(&TensorData::from([5, 4, 3]), false);

    // Float: identical data, checked with the default tolerance.
    let floats = TestTensor::<1>::from([1., 2., 3., 4., 5.]);
    let got = floats.topk(3, /*dim*/ 0);
    got.into_data()
        .assert_approx_eq::<FloatElem>(&TensorData::from([5., 4., 3.]), Tolerance::default());
}

#[test]
fn test_topk() {
    // Int: top-2 along the last axis of a 2x2x3 tensor.
    let ints = TestTensorInt::<3>::from([[[1, 4, 7], [2, 5, 6]], [[3, 0, 9], [8, 2, 8]]]);
    let got = ints.topk(2, /*dim*/ 2);
    got.into_data()
        .assert_eq(&TensorData::from([[[7, 4], [6, 5]], [[9, 3], [8, 8]]]), false);

    // Float: same values as the int case, checked approximately.
    let floats =
        TestTensor::<3>::from([[[1., 4., 7.], [2., 5., 6.]], [[3., 0., 9.], [8., 2., 8.]]]);
    let got = floats.topk(2, /*dim*/ 2);
    got.into_data().assert_approx_eq::<FloatElem>(
        &TensorData::from([[[7., 4.], [6., 5.]], [[9., 3.], [8., 8.]]]),
        Tolerance::default(),
    );
}
7 changes: 6 additions & 1 deletion crates/burn-backend/src/backend/ops/int_tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -929,7 +929,12 @@ pub trait IntTensorOps<B: Backend> {
/// # Returns
///
/// The values of the maximum elements along the dimension.
fn int_topk(tensor: IntTensor<B>, dim: usize, k: usize) -> IntTensor<B>;
fn int_topk(tensor: IntTensor<B>, dim: usize, k: usize) -> IntTensor<B> {
    // Default implementation: sort along `dim`, then keep the first `k` slices.
    let device = Self::int_device(&tensor);
    // Index tensor uses the int dtype configured for this device.
    let dtype = get_device_settings::<B>(&device).int_dtype;
    // Indices 0..k pick the leading `k` entries of the sorted result.
    let k_indices = Self::int_arange(0..k as i64, &device, dtype);
    // The `true` flag presumably means "descending" so the largest values come
    // first (top-k = maxima, per the doc above) — confirm against `int_sort`.
    // NOTE(review): no clamping when `k` exceeds the size of `dim`; behavior
    // then depends on the backend's `int_select` out-of-range handling.
    Self::int_select(Self::int_sort(tensor, dim, true), dim, k_indices)
}

/// Gets the indices of the minimum elements along a dimension.
///
Expand Down
7 changes: 6 additions & 1 deletion crates/burn-backend/src/backend/ops/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1401,7 +1401,12 @@ pub trait FloatTensorOps<B: Backend> {
/// # Returns
///
/// A tensor with the values of the maximum elements of `tensor` along `dim`.
fn float_topk(tensor: FloatTensor<B>, dim: usize, k: usize) -> FloatTensor<B>;
fn float_topk(tensor: FloatTensor<B>, dim: usize, k: usize) -> FloatTensor<B> {
    // Default implementation: sort along `dim`, then keep the first `k` slices.
    let device = Self::float_device(&tensor);
    // Index tensor uses the int dtype configured for this device; built via the
    // backend's int ops (`B::int_arange`) since this is the float-ops trait.
    let dtype = get_device_settings::<B>(&device).int_dtype;
    let k_indices = B::int_arange(0..k as i64, &device, dtype);
    // The `true` flag presumably means "descending" so the largest values come
    // first (top-k = maxima, per the doc above) — confirm against `float_sort`.
    // NOTE(review): no clamping when `k` exceeds the size of `dim`; behavior
    // then depends on the backend's `float_select` out-of-range handling.
    Self::float_select(Self::float_sort(tensor, dim, true), dim, k_indices)
}

/// Gets the indices of the minimum elements of a tensor along an axis.
///
Expand Down
3 changes: 0 additions & 3 deletions crates/burn-candle/src/ops/int_tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -413,9 +413,6 @@ impl<F: FloatCandleElement, I: IntCandleElement> IntTensorOps<Self> for Candle<F
fn int_argtopk(tensor: IntTensor<Self>, dim: usize, k: usize) -> IntTensor<Self> {
panic!("argtopk not implemented for candle backend")
}
fn int_topk(tensor: IntTensor<Self>, dim: usize, k: usize) -> IntTensor<Self> {
panic!("topk not implemented for candle backend")
}

fn int_argmin(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
CandleTensor::new(
Expand Down
4 changes: 0 additions & 4 deletions crates/burn-candle/src/ops/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -593,10 +593,6 @@ impl<F: FloatCandleElement, I: IntCandleElement> FloatTensorOps<Self> for Candle
panic!("argtopk not implemented for candle backend")
}

fn float_topk(tensor: FloatTensor<Self>, dim: usize, k: usize) -> IntTensor<Self> {
panic!("topk not implemented for candle backend")
}

fn float_argmin(tensor: FloatTensor<Self>, dim: usize, out_dtype: IntDType) -> IntTensor<Self> {
CandleTensor::new(
tensor
Expand Down
4 changes: 0 additions & 4 deletions crates/burn-flex/src/ops/float.rs
Original file line number Diff line number Diff line change
Expand Up @@ -924,10 +924,6 @@ impl FloatTensorOps<Flex> for Flex {
unimplemented!("float_argtopk not implemented for flex")
}

fn float_topk(_tensor: FloatTensor<Flex>, _dim: usize, _k: usize) -> IntTensor<Flex> {
unimplemented!("float_topk not implemented for flex")
}

fn float_argmin(
tensor: FloatTensor<Flex>,
dim: usize,
Expand Down
4 changes: 0 additions & 4 deletions crates/burn-flex/src/ops/int.rs
Original file line number Diff line number Diff line change
Expand Up @@ -615,10 +615,6 @@ impl IntTensorOps<Flex> for Flex {
panic!("argtopk not implemented for flex")
}

fn int_topk(_tensor: IntTensor<Flex>, _dim: usize, _k: usize) -> IntTensor<Flex> {
panic!("topk not implemented for flex")
}

fn int_argmin(tensor: IntTensor<Flex>, dim: usize) -> IntTensor<Flex> {
crate::ops::reduce::argmin(tensor, dim)
}
Expand Down
4 changes: 0 additions & 4 deletions crates/burn-ndarray/src/ops/int_tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -338,10 +338,6 @@ where
unimplemented!("argtopk not implemented for ndarray");
}

fn int_topk(_tensor: NdArrayTensor, _dim: usize, _k: usize) -> NdArrayTensor {
unimplemented!("topk not implemented for ndarray");
}

fn int_argmin(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor {
// Use view() for zero-copy on borrowed storage
execute_with_int_dtype!(tensor, E, |array: SharedArray<E>| {
Expand Down
4 changes: 0 additions & 4 deletions crates/burn-ndarray/src/ops/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -470,10 +470,6 @@ where
unimplemented!("float_argtopk not implemented for ndarray")
}

fn float_topk(_tensor: FloatTensor<Self>, _dim: usize, _k: usize) -> NdArrayTensor {
unimplemented!("float_topk not implemented for ndarray")
}

fn float_argmin(tensor: FloatTensor<Self>, dim: usize, out_dtype: IntDType) -> NdArrayTensor {
// Use view() for zero-copy on borrowed storage
execute_with_int_out_dtype!(out_dtype, I, {
Expand Down
2 changes: 1 addition & 1 deletion crates/burn-store/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ burn-tch = { workspace = true, optional = true, features = ["default"] }
burn-wgpu = { workspace = true, optional = true, features = ["default"] }

[dev-dependencies]
# burn-import = { path = "../burn-import", version = "=0.21.0-pre.5" } # disabled (circular dep in publish, only for bench)
# burn-import = { path = "../burn-import", version = "0.21.0" } # disabled (circular dep in publish, only for bench)
burn-flex = { workspace = true, features = ["default"] }
burn-nn = { workspace = true }
divan = "0.1"
Expand Down
10 changes: 10 additions & 0 deletions crates/burn-tch/src/ops/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -541,6 +541,16 @@ impl TchOps {
)
}

/// Top-k values along `dim`, delegating to LibTorch's native `topk`
/// (args: k, dim, largest = true, sorted = true). The paired indices
/// output is discarded; only the values tensor is returned.
pub fn topk(tensor: TchTensor, dim: usize, k: usize) -> TchTensor {
    let result = tensor.tensor.topk(k as i64, dim as i64, true, true);
    TchTensor::from_existing(result.0, tensor.storage)
}

/// Indices of the top-k values along `dim`, delegating to LibTorch's native
/// `topk` (args: k, dim, largest = true, sorted = true). The paired values
/// output is discarded; only the indices tensor is returned.
pub fn argtopk(tensor: TchTensor, dim: usize, k: usize) -> TchTensor {
    let result = tensor.tensor.topk(k as i64, dim as i64, true, true);
    TchTensor::from_existing(result.1, tensor.storage)
}

pub fn cumsum(tensor: TchTensor, dim: usize) -> TchTensor {
TchTensor::from_existing(
tensor.tensor.cumsum(dim as i64, tensor.tensor.kind()),
Expand Down
4 changes: 2 additions & 2 deletions crates/burn-tch/src/ops/int_tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -348,8 +348,8 @@ impl<E: TchElement> IntTensorOps<Self> for LibTorch<E> {
panic!("argtopk not implemented for torch")
}

fn int_topk(_tensor: TchTensor, _dim: usize, _k: usize) -> TchTensor {
panic!("topk not implemented for torch")
fn int_topk(tensor: TchTensor, dim: usize, k: usize) -> TchTensor {
TchOps::topk(tensor, dim, k)
}

fn int_argmin(tensor: TchTensor, dim: usize) -> TchTensor {
Expand Down
12 changes: 6 additions & 6 deletions crates/burn-tch/src/ops/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -345,16 +345,16 @@ impl<E: TchElement> FloatTensorOps<Self> for LibTorch<E> {
}

fn float_argtopk(
_tensor: TchTensor,
_dim: usize,
_k: usize,
tensor: TchTensor,
dim: usize,
k: usize,
_indices_dtype: IntDType,
) -> TchTensor {
unimplemented!("argtopk not implemented for Torch")
TchOps::argtopk(tensor, dim, k)
}

fn float_topk(_tensor: TchTensor, _dim: usize, _k: usize) -> TchTensor {
unimplemented!("topk not implemented for Torch")
fn float_topk(tensor: TchTensor, dim: usize, k: usize) -> TchTensor {
TchOps::topk(tensor, dim, k)
}

fn float_argmin(tensor: TchTensor, dim: usize, _out_dtype: IntDType) -> TchTensor {
Expand Down
3 changes: 0 additions & 3 deletions crates/burn-tensor/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,3 @@ pub use burn_backend::{
AllocationProperty, Bytes, DeviceSettings, StreamId, bf16, f16, get_device_settings, read_sync,
set_default_dtypes, try_read_sync,
};

// mod device;
// pub use device::*;
2 changes: 1 addition & 1 deletion examples/custom-csv-dataset/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use std::{
/// Panics if the download cannot be completed or the content of the file cannot be written to disk.
pub fn download_csv_if_missing() -> PathBuf {
// Point file to current example directory
let example_dir = Path::new(file!()).parent().unwrap().parent().unwrap();
let example_dir = Path::new(env!("CARGO_MANIFEST_DIR"));
let file_name = example_dir.join("diabetes.csv");

if file_name.exists() {
Expand Down
Binary file modified examples/mnist-inference-web/model.bin
Binary file not shown.
2 changes: 1 addition & 1 deletion examples/modern-lstm/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[package]
edition.workspace = true
name = "modern-lstm"
version = "0.21.0-pre.5"
version = "0.21.0"

[lints]
workspace = true
Expand Down
2 changes: 1 addition & 1 deletion examples/wgan/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "wgan"
version = "0.21.0-pre.5"
version = "0.21.0"
edition.workspace = true

[lints]
Expand Down
Loading