Skip to content

Commit fef87e7

Browse files
committed
simplify
1 parent e04ebd9 commit fef87e7

File tree

5 files changed

+125
-52
lines changed

5 files changed

+125
-52
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

rust/lance-testing/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ arrow-schema = { workspace = true }
1616
rand = { workspace = true }
1717
num-traits = { workspace = true }
1818
lance-arrow = { workspace = true }
19+
tokio = { workspace = true, features = ["sync"] }
1920

2021
[lints]
2122
workspace = true

rust/lance-testing/src/allocator.rs

Lines changed: 83 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,10 @@
44
//! Simple memory allocation tracking for tests
55
//!
66
//! This module provides a global allocator wrapper that tracks memory usage statistics.
7-
//! Unlike the previous tracking-allocator implementation, this does not rely on tracing
8-
//! spans, ensuring all allocations are captured regardless of span propagation.
97
108
use std::alloc::{GlobalAlloc, Layout};
9+
use std::future::Future;
1110
use std::sync::atomic::{AtomicIsize, AtomicUsize, Ordering};
12-
use std::sync::Mutex;
1311

1412
/// Statistics about memory allocations
1513
#[derive(Debug, Clone, Copy, Default)]
@@ -143,6 +141,82 @@ unsafe impl<A: GlobalAlloc> GlobalAlloc for TrackingAllocator<A> {
143141
}
144142
}
145143

144+
/// Record allocation statistics while executing a function
145+
///
146+
/// This function automatically:
147+
/// - Acquires a lock to prevent concurrent tests from interfering
148+
/// - Resets allocation counters before execution
149+
/// - Executes the provided closure
150+
/// - Returns the allocation statistics
151+
///
152+
/// # Example
153+
///
154+
/// ```rust
155+
/// use lance_testing::allocator::record_allocations;
156+
///
157+
/// let stats = record_allocations(|| {
158+
/// // Allocate some memory
159+
/// let v: Vec<u8> = vec![0; 1024 * 1024]; // 1MB
160+
/// drop(v);
161+
/// });
162+
///
163+
/// assert!(stats.total_bytes_allocated >= 1024 * 1024);
164+
/// ```
165+
pub fn record_allocations<F, R>(f: F) -> AllocStats
166+
where
167+
F: FnOnce() -> R,
168+
{
169+
// Acquire lock to prevent concurrent tests from interfering
170+
let _guard = ALLOC_TEST_MUTEX.blocking_lock();
171+
172+
// Reset stats before execution
173+
GLOBAL_STATS.reset();
174+
175+
// Execute the closure
176+
let _ = f();
177+
178+
// Return the stats
179+
GLOBAL_STATS.get_stats()
180+
}
181+
182+
/// Record allocation statistics while executing a future
183+
///
184+
/// This function automatically:
185+
/// - Acquires a lock to prevent concurrent tests from interfering
186+
/// - Resets allocation counters before execution
187+
/// - Executes the provided closure
188+
/// - Returns the allocation statistics
189+
///
190+
/// # Example
191+
///
192+
/// ```rust
193+
/// use lance_testing::allocator::record_allocations_async;
194+
///
195+
/// let stats = record_allocations_async(async {
196+
/// // Allocate some memory
197+
/// let v: Vec<u8> = vec![0; 1024 * 1024]; // 1MB
198+
/// drop(v);
199+
/// });
200+
///
201+
/// assert!(stats.total_bytes_allocated >= 1024 * 1024);
202+
/// ```
203+
pub async fn record_allocations_async<F, R>(f: F) -> AllocStats
204+
where
205+
F: Future<Output = R>,
206+
{
207+
// Acquire lock to prevent concurrent tests from interfering
208+
let _guard = ALLOC_TEST_MUTEX.lock().await;
209+
210+
// Reset stats before execution
211+
GLOBAL_STATS.reset();
212+
213+
// Execute the closure
214+
let _ = f.await;
215+
216+
// Return the stats
217+
GLOBAL_STATS.get_stats()
218+
}
219+
146220
/// Get current allocation statistics
147221
///
148222
/// Returns statistics about memory allocations made by Lance's Rust code.
@@ -153,24 +227,22 @@ unsafe impl<A: GlobalAlloc> GlobalAlloc for TrackingAllocator<A> {
153227
/// - Allocations made by other native libraries (e.g., PyArrow, NumPy)
154228
/// - Memory-mapped files or shared memory
155229
///
156-
/// For comprehensive memory tracking in Python tests, use the
157-
/// `lance.testing.track_memory()` context manager which combines these stats
158-
/// with `tracemalloc` and PyArrow's memory pool statistics.
230+
/// **Note**: For tests, prefer using `record_allocations()` which automatically
231+
/// handles locking and reset.
159232
pub fn get_alloc_stats() -> AllocStats {
160233
GLOBAL_STATS.get_stats()
161234
}
162235

163236
/// Reset allocation statistics to zero
164237
///
165-
/// This resets all counters. Note that for accurate tracking in tests,
166-
/// you should use the `#[track_alloc]` macro which ensures single-threaded
167-
/// execution and automatic reset/collection.
238+
/// **Note**: For tests, prefer using `record_allocations()` which automatically
239+
/// handles locking and reset.
168240
pub fn reset_alloc_stats() {
169241
GLOBAL_STATS.reset();
170242
}
171243

172244
/// Mutex to ensure single-threaded execution of tests that track allocations
173245
///
174246
/// This prevents stats from different tests running in parallel from mixing.
175-
/// The `#[track_alloc]` macro uses this automatically.
176-
pub static ALLOC_TEST_MUTEX: Mutex<()> = Mutex::new(());
247+
/// The `record_allocations()` function uses this automatically.
248+
static ALLOC_TEST_MUTEX: tokio::sync::Mutex<()> = tokio::sync::Mutex::const_new(());

rust/lance/src/dataset/write/insert.rs

Lines changed: 30 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -445,8 +445,11 @@ struct WriteContext<'a> {
445445

446446
#[cfg(test)]
447447
mod test {
448+
use all_asserts::assert_gt;
448449
use arrow_array::StructArray;
449450
use arrow_schema::{DataType, Field, Schema};
451+
use lance_datafusion::datagen::DatafusionDatagenExt;
452+
use lance_datagen::{array, gen_batch, BatchCount, ByteCount, RoundingBehavior};
450453

451454
use crate::session::Session;
452455

@@ -502,57 +505,43 @@ mod test {
502505
async fn test_insert_memory_tracking() {
503506
// Test that insert doesn't load all data into memory at once
504507
// when writing 100MB of data in 10MB batches
505-
use arrow_array::Int32Array;
506-
use arrow_schema::ArrowError;
507-
508-
// Create 100MB of data in 10 batches of 10MB each
509-
let batch_size_mb = 10;
510-
let num_batches = 10;
511-
let total_rows_per_batch = (batch_size_mb * 1024 * 1024) / (4 * 3); // Int32 = 4 bytes, 3 columns
512-
513-
let schema = Arc::new(Schema::new(vec![
514-
Field::new("a", DataType::Int32, false),
515-
Field::new("b", DataType::Int32, false),
516-
Field::new("c", DataType::Int32, false),
517-
]));
518-
519-
// Create batches on demand
520-
let batches: Vec<std::result::Result<RecordBatch, ArrowError>> = (0..num_batches)
521-
.map(|i| {
522-
let start = i * total_rows_per_batch;
523-
let end = (i + 1) * total_rows_per_batch;
524-
let a = Int32Array::from_iter_values(start..end);
525-
let b = Int32Array::from_iter_values(start..end);
526-
let c = Int32Array::from_iter_values(start..end);
527-
RecordBatch::try_new(schema.clone(), vec![Arc::new(a), Arc::new(b), Arc::new(c)])
528-
})
529-
.collect();
530-
531-
lance_testing::allocator::reset_alloc_stats();
508+
let batch_size = 10 * 1024 * 1024; // 10MB
509+
510+
let stats = lance_testing::allocator::record_allocations_async(async {
511+
let tmp_dir = tempfile::tempdir().unwrap();
512+
let tmp_path = tmp_dir.path().to_str().unwrap().to_string();
513+
// Create batches on demand
514+
// Create a stream of 100MB of data, in batches
515+
let num_batches = BatchCount::from(10);
516+
let data = gen_batch()
517+
.col("a", array::rand_type(&DataType::Int32))
518+
.into_df_stream_bytes(
519+
ByteCount::from(batch_size),
520+
num_batches,
521+
RoundingBehavior::RoundDown,
522+
)
523+
.unwrap();
532524

533-
let _dataset = InsertBuilder::new("memory://")
534-
.execute_stream(RecordBatchIterator::new(
535-
batches.into_iter(),
536-
schema.clone(),
537-
))
538-
.await
539-
.unwrap();
525+
let _dataset = InsertBuilder::new(&tmp_path)
526+
.execute_stream(data)
527+
.await
528+
.unwrap();
529+
})
530+
.await;
540531

541-
let stats = lance_testing::allocator::get_alloc_stats();
532+
assert_gt!(stats.total_bytes_allocated, 100 * 1024 * 1024);
542533

543534
// The key test: we shouldn't load all 100MB at once
544-
// Allow 5x the batch size to account for overhead and buffering
545-
// (this would still catch if we buffered all 100MB)
546-
let max_allowed_mb = batch_size_mb * 5;
547-
let max_allowed_bytes = (max_allowed_mb * 1024 * 1024) as isize;
535+
// Allow 2x the batch size to account for overhead and buffering
536+
let max_allowed_bytes = (batch_size * 2) as isize;
548537

549538
let peak_mb = stats.max_bytes_allocated as f64 / (1024.0 * 1024.0);
550539

551540
assert!(
552541
stats.max_bytes_allocated <= max_allowed_bytes,
553-
"Peak memory {} MB exceeded limit {} MB",
542+
"Peak memory {:.2} MB exceeded limit {:.2} MB",
554543
peak_mb,
555-
max_allowed_mb
544+
max_allowed_bytes as f64 / (1024.0 * 1024.0)
556545
);
557546
}
558547
}

rust/lance/src/lib.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,3 +108,13 @@ pub mod deps {
108108
pub use arrow_schema;
109109
pub use datafusion;
110110
}
111+
112+
#[cfg(test)]
113+
mod tests {
114+
use lance_testing::allocator::TrackingAllocator;
115+
use std::alloc::System;
116+
117+
// Enable allocation statistics for memory tracking tests.
118+
#[global_allocator]
119+
static GLOBAL: TrackingAllocator<System> = TrackingAllocator::new(System);
120+
}

0 commit comments

Comments
 (0)