diff --git a/Cargo.lock b/Cargo.lock index 449f9b17e..8e05239b5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1090,6 +1090,7 @@ dependencies = [ "kvm-ioctls", "lazy_static", "libc", + "lockfree", "log", "metrics", "metrics-exporter-prometheus", @@ -1564,6 +1565,15 @@ dependencies = [ "scopeguard", ] +[[package]] +name = "lockfree" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "74ee94b5ad113c7cb98c5a040f783d0952ee4fe100993881d1673c2cb002dd23" +dependencies = [ + "owned-alloc", +] + [[package]] name = "log" version = "0.4.27" @@ -1921,6 +1931,12 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" +[[package]] +name = "owned-alloc" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "30fceb411f9a12ff9222c5f824026be368ff15dc2f13468d850c7d3f502205d6" + [[package]] name = "page_size" version = "0.6.0" diff --git a/docs/signal-handlers-development-notes.md b/docs/signal-handlers-development-notes.md index fca9d31a9..e5c48b57b 100644 --- a/docs/signal-handlers-development-notes.md +++ b/docs/signal-handlers-development-notes.md @@ -1,11 +1,12 @@ # Signal Handling in Hyperlight -Hyperlight registers custom signal handlers to intercept and manage specific signals, primarily `SIGSYS` and `SIGRTMIN`. Here's an overview of the registration process: -- **Preserving Old Handlers**: When registering a new signal handler, Hyperlight first retrieves and stores the existing handler using `OnceCell`. This allows Hyperlight to delegate signals to the original handler if necessary. -- **Custom Handlers**: - - **`SIGSYS` Handler**: Captures disallowed syscalls enforced by seccomp. If the signal originates from a hyperlight thread, Hyperlight logs the syscall details. Otherwise, it delegates the signal to the previously registered handler. - - **`SIGRTMIN` Handler**: Utilized for inter-thread signaling, such as execution cancellation. Similar to SIGSYS, it distinguishes between application and non-hyperlight threads to determine how to handle the signal. -- **Thread Differentiation**: Hyperlight uses thread-local storage (IS_HYPERLIGHT_THREAD) to identify whether the current thread is a hyperlight thread. This distinction ensures that signals are handled appropriately based on the thread's role. +Hyperlight registers custom signal handlers to intercept and manage specific signals, primarily `SIGSYS`, `SIGRTMIN`, and `SIGSEGV`. Here's an overview of the registration process: + +- **Preserving Old Handlers**: When registering a new signal handler, Hyperlight first retrieves and stores the existing handler using either `OnceCell` or a `static AtomicPtr`. This allows Hyperlight to delegate signals to the original handler if necessary. +- **Custom Handlers**: +- **`SIGSYS` Handler**: Captures disallowed syscalls enforced by seccomp. If the signal originates from a hyperlight thread, Hyperlight logs the syscall details. Otherwise, it delegates the signal to the previously registered handler. +- **`SIGRTMIN` Handler**: Utilized for inter-thread signaling, such as execution cancellation. Similar to SIGSYS, it distinguishes between application and non-hyperlight threads to determine how to handle the signal. +- **`SIGSEGV` Handler**: Handles segmentation faults for dirty page tracking of host memory mapped into a VM.
If the signal applies to an address that is mapped to a VM, it is processed by Hyperlight; otherwise, it is passed to the original handler. ## Potential Issues and Considerations @@ -15,3 +16,14 @@ Hyperlight registers custom signal handlers to intercept and manage specific sig - **Invalidation of `old_handler`**: The stored old_handler reference may no longer point to a valid handler, causing undefined behavior when Hyperlight attempts to delegate signals. - **Loss of Custom Handling**: Hyperlight's custom handler might not be invoked as expected, disrupting its ability to enforce syscall restrictions or manage inter-thread signals. +### Debugging and Signal Handling + +By default, when debugging a host application/test/example with GDB or LLDB, the debugger will handle the `SIGSEGV` signal by breaking when it is raised. To prevent this and let Hyperlight handle the signal, enter the following in the debug console: + +#### LLDB + +```process handle SIGSEGV -n true -p true -s false``` + +#### GDB + +```handle SIGSEGV nostop noprint pass``` diff --git a/src/hyperlight_common/src/mem.rs b/src/hyperlight_common/src/mem.rs index 4e5448ae9..4dffc6c4a 100644 --- a/src/hyperlight_common/src/mem.rs +++ b/src/hyperlight_common/src/mem.rs @@ -17,6 +17,8 @@ limitations under the License. pub const PAGE_SHIFT: u64 = 12; pub const PAGE_SIZE: u64 = 1 << 12; pub const PAGE_SIZE_USIZE: usize = 1 << 12; +// The number of pages in 1 "block". A single u64 can be used as a bitmap to keep track of all dirty pages in a block. +pub const PAGES_IN_BLOCK: usize = 64; /// A memory region in the guest address space #[derive(Debug, Clone, Copy)] diff --git a/src/hyperlight_host/Cargo.toml b/src/hyperlight_host/Cargo.toml index bc744c05c..ebf395d43 100644 --- a/src/hyperlight_host/Cargo.toml +++ b/src/hyperlight_host/Cargo.toml @@ -44,6 +44,7 @@ anyhow = "1.0" metrics = "0.24.2" serde_json = "1.0" elfcore = "2.0" +lockfree = "0.5" [target.'cfg(windows)'.dependencies] windows = { version = "0.61", features = [ diff --git a/src/hyperlight_host/benches/benchmarks.rs b/src/hyperlight_host/benches/benchmarks.rs index c9160ff52..a02ad96cc 100644 --- a/src/hyperlight_host/benches/benchmarks.rs +++ b/src/hyperlight_host/benches/benchmarks.rs @@ -79,37 +79,65 @@ fn guest_call_benchmark(c: &mut Criterion) { group.finish(); } -fn guest_call_benchmark_large_param(c: &mut Criterion) { +fn guest_call_benchmark_large_params(c: &mut Criterion) { let mut group = c.benchmark_group("guest_functions_with_large_parameters"); #[cfg(target_os = "windows")] group.sample_size(10); // This benchmark is very slow on Windows, so we reduce the sample size to avoid long test runs.
- // This benchmark includes time to first clone a vector and string, so it is not a "pure' benchmark of the guest call, but it's still useful - group.bench_function("guest_call_with_large_parameters", |b| { - const SIZE: usize = 50 * 1024 * 1024; // 50 MB - let large_vec = vec![0u8; SIZE]; - let large_string = unsafe { String::from_utf8_unchecked(large_vec.clone()) }; // Safety: indeed above vec is valid utf8 + // Helper function to create a benchmark for a specific size + let create_benchmark = |group: &mut criterion::BenchmarkGroup<_>, size_mb: usize| { + let benchmark_name = format!("guest_call_with_2_large_parameters_{}mb each", size_mb); + group.bench_function(&benchmark_name, |b| { + let size = size_mb * 1024 * 1024; // Convert MB to bytes + let large_vec = vec![0u8; size]; + let large_string = unsafe { String::from_utf8_unchecked(large_vec.clone()) }; // Safety: indeed above vec is valid utf8 - let mut config = SandboxConfiguration::default(); - config.set_input_data_size(2 * SIZE + (1024 * 1024)); // 2 * SIZE + 1 MB, to allow 1MB for the rest of the serialized function call - config.set_heap_size(SIZE as u64 * 15); + let mut config = SandboxConfiguration::default(); + config.set_input_data_size(2 * size + (1024 * 1024)); - let sandbox = UninitializedSandbox::new( - GuestBinary::FilePath(simple_guest_as_string().unwrap()), - Some(config), - ) - .unwrap(); - let mut sandbox = sandbox.evolve(Noop::default()).unwrap(); + if size < 50 * 1024 * 1024 { + config.set_heap_size(size as u64 * 16); + } else { + config.set_heap_size(size as u64 * 11); // Set to 1GB for larger sizes + } - b.iter(|| { - sandbox - .call_guest_function_by_name::<()>( - "LargeParameters", - (large_vec.clone(), large_string.clone()), - ) - .unwrap() + let sandbox = UninitializedSandbox::new( + GuestBinary::FilePath(simple_guest_as_string().unwrap()), + Some(config), + ) + .unwrap(); + let mut sandbox = sandbox.evolve(Noop::default()).unwrap(); + + b.iter_custom(|iters| { + let mut total_duration = std::time::Duration::new(0, 0); + + for _ in 0..iters { + // Clone the data (not measured) + let vec_clone = large_vec.clone(); + let string_clone = large_string.clone(); + + // Measure only the guest function call + let start = std::time::Instant::now(); + sandbox + .call_guest_function_by_name::<()>( + "LargeParameters", + (vec_clone, string_clone), + ) + .unwrap(); + total_duration += start.elapsed(); + } + + total_duration + }); }); - }); + }; + + // Create benchmarks for different sizes + create_benchmark(&mut group, 5); // 5MB + create_benchmark(&mut group, 10); // 10MB + create_benchmark(&mut group, 20); // 20MB + create_benchmark(&mut group, 40); // 40MB + create_benchmark(&mut group, 60); // 60MB group.finish(); } @@ -153,9 +181,143 @@ fn sandbox_benchmark(c: &mut Criterion) { group.finish(); } +fn sandbox_heap_size_benchmark(c: &mut Criterion) { + let mut group = c.benchmark_group("sandbox_heap_sizes"); + + // Helper function to create sandbox with specific heap size + let create_sandbox_with_heap_size = |heap_size_mb: Option| { + let path = simple_guest_as_string().unwrap(); + let config = if let Some(size_mb) = heap_size_mb { + let mut config = SandboxConfiguration::default(); + config.set_heap_size(size_mb * 1024 * 1024); // Convert MB to bytes + Some(config) + } else { + None + }; + + let uninit_sandbox = + UninitializedSandbox::new(GuestBinary::FilePath(path), config).unwrap(); + uninit_sandbox.evolve(Noop::default()).unwrap() + }; + + // Benchmark sandbox creation with default heap size + 
group.bench_function("create_sandbox_default_heap", |b| { + b.iter_with_large_drop(|| create_sandbox_with_heap_size(None)); + }); + + // Benchmark sandbox creation with 50MB heap + group.bench_function("create_sandbox_50mb_heap", |b| { + b.iter_with_large_drop(|| create_sandbox_with_heap_size(Some(50))); + }); + + // Benchmark sandbox creation with 100MB heap + group.bench_function("create_sandbox_100mb_heap", |b| { + b.iter_with_large_drop(|| create_sandbox_with_heap_size(Some(100))); + }); + + // Benchmark sandbox creation with 250MB heap + group.bench_function("create_sandbox_250mb_heap", |b| { + b.iter_with_large_drop(|| create_sandbox_with_heap_size(Some(250))); + }); + + // Benchmark sandbox creation with 500MB heap + group.bench_function("create_sandbox_500mb_heap", |b| { + b.iter_with_large_drop(|| create_sandbox_with_heap_size(Some(500))); + }); + + // Benchmark sandbox creation with 995MB heap (close to the limit of 1GB for a Sandbox ) + group.bench_function("create_sandbox_995mb_heap", |b| { + b.iter_with_large_drop(|| create_sandbox_with_heap_size(Some(995))); + }); + + group.finish(); +} + +fn guest_call_heap_size_benchmark(c: &mut Criterion) { + let mut group = c.benchmark_group("guest_call_heap_sizes"); + + // Helper function to create sandbox with specific heap size + let create_sandbox_with_heap_size = |heap_size_mb: Option| { + let path = simple_guest_as_string().unwrap(); + let config = if let Some(size_mb) = heap_size_mb { + let mut config = SandboxConfiguration::default(); + config.set_heap_size(size_mb * 1024 * 1024); // Convert MB to bytes + Some(config) + } else { + None + }; + + let uninit_sandbox = + UninitializedSandbox::new(GuestBinary::FilePath(path), config).unwrap(); + uninit_sandbox.evolve(Noop::default()).unwrap() + }; + + // Benchmark guest function call with default heap size + group.bench_function("guest_call_default_heap", |b| { + let mut sandbox = create_sandbox_with_heap_size(None); + b.iter(|| { + sandbox + .call_guest_function_by_name::("Echo", "hello\n".to_string()) + .unwrap() + }); + }); + + // Benchmark guest function call with 50MB heap + group.bench_function("guest_call_50mb_heap", |b| { + let mut sandbox = create_sandbox_with_heap_size(Some(50)); + b.iter(|| { + sandbox + .call_guest_function_by_name::("Echo", "hello\n".to_string()) + .unwrap() + }); + }); + + // Benchmark guest function call with 100MB heap + group.bench_function("guest_call_100mb_heap", |b| { + let mut sandbox = create_sandbox_with_heap_size(Some(100)); + b.iter(|| { + sandbox + .call_guest_function_by_name::("Echo", "hello\n".to_string()) + .unwrap() + }); + }); + + // Benchmark guest function call with 250MB heap + group.bench_function("guest_call_250mb_heap", |b| { + let mut sandbox = create_sandbox_with_heap_size(Some(250)); + b.iter(|| { + sandbox + .call_guest_function_by_name::("Echo", "hello\n".to_string()) + .unwrap() + }); + }); + + // Benchmark guest function call with 500MB heap + group.bench_function("guest_call_500mb_heap", |b| { + let mut sandbox = create_sandbox_with_heap_size(Some(500)); + b.iter(|| { + sandbox + .call_guest_function_by_name::("Echo", "hello\n".to_string()) + .unwrap() + }); + }); + + // Benchmark guest function call with 995MB heap + group.bench_function("guest_call_995mb_heap", |b| { + let mut sandbox = create_sandbox_with_heap_size(Some(995)); + b.iter(|| { + sandbox + .call_guest_function_by_name::("Echo", "hello\n".to_string()) + .unwrap() + }); + }); + + group.finish(); +} + criterion_group! 
{ name = benches; config = Criterion::default(); - targets = guest_call_benchmark, sandbox_benchmark, guest_call_benchmark_large_param + targets = guest_call_benchmark, sandbox_benchmark, sandbox_heap_size_benchmark, guest_call_benchmark_large_params, guest_call_heap_size_benchmark } criterion_main!(benches); diff --git a/src/hyperlight_host/src/hypervisor/hyperv_linux.rs b/src/hyperlight_host/src/hypervisor/hyperv_linux.rs index 90e91f496..f59ce2fe8 100644 --- a/src/hyperlight_host/src/hypervisor/hyperv_linux.rs +++ b/src/hyperlight_host/src/hypervisor/hyperv_linux.rs @@ -29,6 +29,8 @@ use std::sync::Arc; use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; use log::{LevelFilter, error}; +#[cfg(mshv3)] +use mshv_bindings::MSHV_GPAP_ACCESS_OP_CLEAR; #[cfg(mshv2)] use mshv_bindings::hv_message; use mshv_bindings::{ @@ -76,6 +78,9 @@ use crate::sandbox::SandboxConfiguration; use crate::sandbox::uninitialized::SandboxRuntimeConfig; use crate::{Result, log_then_return, new_error}; +#[cfg(mshv2)] +const CLEAR_DIRTY_BIT_FLAG: u64 = 0b100; + #[cfg(gdb)] mod debug { use std::sync::{Arc, Mutex}; @@ -351,6 +356,7 @@ impl HypervLinuxDriver { vm_fd.initialize()?; vm_fd }; + vm_fd.enable_dirty_page_tracking()?; let mut vcpu_fd = vm_fd.create_vcpu(0)?; @@ -391,13 +397,31 @@ impl HypervLinuxDriver { (None, None) }; + let mut base_pfn = u64::MAX; + let mut total_size: usize = 0; + mem_regions.iter().try_for_each(|region| { - let mshv_region = region.to_owned().into(); + let mshv_region: mshv_user_mem_region = region.to_owned().into(); + if base_pfn == u64::MAX { + base_pfn = mshv_region.guest_pfn; + } + total_size += mshv_region.size as usize; vm_fd.map_user_memory(mshv_region) })?; Self::setup_initial_sregs(&mut vcpu_fd, pml4_ptr.absolute()?)?; + // Get/clear the dirty page bitmap; mshv sets all the bits dirty at initialization. + // If we don't clear them, we end up taking a complete snapshot of memory page by page, which gets + // progressively slower as the sandbox size increases. + // The downside of doing this here is that the call to get_dirty_log will take longer as the number of pages increases, + // but for larger sandboxes it's easily cheaper than copying all the pages. + + #[cfg(mshv2)] + vm_fd.get_dirty_log(base_pfn, total_size, CLEAR_DIRTY_BIT_FLAG)?; + #[cfg(mshv3)] + vm_fd.get_dirty_log(base_pfn, total_size, MSHV_GPAP_ACCESS_OP_CLEAR as u8)?; + let interrupt_handle = Arc::new(LinuxInterruptHandle { running: AtomicU64::new(0), cancel_requested: AtomicBool::new(false), @@ -863,6 +887,27 @@ impl Hypervisor for HypervLinuxDriver { self.interrupt_handle.clone() } + fn get_and_clear_dirty_pages(&mut self) -> Result<Vec<u64>> { + let first_mshv_region: mshv_user_mem_region = self + .mem_regions + .first() + .ok_or(new_error!( + "tried to get dirty page bitmap of 0-sized region" + ))?
+ .to_owned() + .into(); + let total_size = self.mem_regions.iter().map(|r| r.guest_region.len()).sum(); + let res = self.vm_fd.get_dirty_log( + first_mshv_region.guest_pfn, + total_size, + #[cfg(mshv2)] + CLEAR_DIRTY_BIT_FLAG, + #[cfg(mshv3)] + (MSHV_GPAP_ACCESS_OP_CLEAR as u8), + )?; + Ok(res) + } + #[cfg(crashdump)] fn crashdump_context(&self) -> Result> { if self.rt_cfg.guest_core_dump { diff --git a/src/hyperlight_host/src/hypervisor/hyperv_windows.rs b/src/hyperlight_host/src/hypervisor/hyperv_windows.rs index cd0398854..0960b4658 100644 --- a/src/hyperlight_host/src/hypervisor/hyperv_windows.rs +++ b/src/hyperlight_host/src/hypervisor/hyperv_windows.rs @@ -55,6 +55,7 @@ use super::{ use super::{HyperlightExit, Hypervisor, InterruptHandle, VirtualCPU}; use crate::hypervisor::fpu::FP_CONTROL_WORD_DEFAULT; use crate::hypervisor::wrappers::WHvGeneralRegisters; +use crate::mem::bitmap::new_page_bitmap; use crate::mem::memory_region::{MemoryRegion, MemoryRegionFlags}; use crate::mem::ptr::{GuestPtr, RawPtr}; #[cfg(crashdump)] @@ -606,6 +607,12 @@ impl Hypervisor for HypervWindowsDriver { Ok(()) } + fn get_and_clear_dirty_pages(&mut self) -> Result> { + // For now we just mark all pages dirty which is the equivalent of taking a full snapshot + let total_size = self.mem_regions.iter().map(|r| r.guest_region.len()).sum(); + new_page_bitmap(total_size, true) + } + #[instrument(err(Debug), skip_all, parent = Span::current(), level = "Trace")] unsafe fn map_region(&mut self, _rgn: &MemoryRegion) -> Result<()> { log_then_return!("Mapping host memory into the guest not yet supported on this platform"); diff --git a/src/hyperlight_host/src/hypervisor/kvm.rs b/src/hyperlight_host/src/hypervisor/kvm.rs index 3da9786cd..0887147e3 100644 --- a/src/hyperlight_host/src/hypervisor/kvm.rs +++ b/src/hyperlight_host/src/hypervisor/kvm.rs @@ -21,7 +21,10 @@ use std::sync::Arc; use std::sync::Mutex; use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; -use kvm_bindings::{KVM_MEM_READONLY, kvm_fpu, kvm_regs, kvm_userspace_memory_region}; +use hyperlight_common::mem::{PAGE_SIZE_USIZE, PAGES_IN_BLOCK}; +use kvm_bindings::{ + KVM_MEM_LOG_DIRTY_PAGES, KVM_MEM_READONLY, kvm_fpu, kvm_regs, kvm_userspace_memory_region, +}; use kvm_ioctls::Cap::UserMemory; use kvm_ioctls::{Kvm, VcpuExit, VcpuFd, VmFd}; use log::LevelFilter; @@ -43,7 +46,8 @@ use super::{ use super::{HyperlightExit, Hypervisor, InterruptHandle, LinuxInterruptHandle, VirtualCPU}; #[cfg(gdb)] use crate::HyperlightError; -use crate::mem::memory_region::{MemoryRegion, MemoryRegionFlags}; +use crate::mem::bitmap::{bit_index_iterator, new_page_bitmap}; +use crate::mem::memory_region::{MemoryRegion, MemoryRegionFlags, MemoryRegionType}; use crate::mem::ptr::{GuestPtr, RawPtr}; use crate::sandbox::SandboxConfiguration; #[cfg(crashdump)] @@ -284,7 +288,7 @@ mod debug { /// A Hypervisor driver for KVM on Linux pub(crate) struct KVMDriver { _kvm: Kvm, - _vm_fd: VmFd, + vm_fd: VmFd, vcpu_fd: VcpuFd, entrypoint: u64, orig_rsp: GuestPtr, @@ -329,7 +333,7 @@ impl KVMDriver { userspace_addr: region.host_region.start as u64, flags: match perm_flags { MemoryRegionFlags::READ => KVM_MEM_READONLY, - _ => 0, // normal, RWX + _ => KVM_MEM_LOG_DIRTY_PAGES, // normal, RWX }, }; unsafe { vm_fd.set_user_memory_region(kvm_region) } @@ -378,7 +382,7 @@ impl KVMDriver { #[allow(unused_mut)] let mut hv = Self { _kvm: kvm, - _vm_fd: vm_fd, + vm_fd, vcpu_fd, entrypoint, orig_rsp: rsp_gp, @@ -734,6 +738,45 @@ impl Hypervisor for KVMDriver { self.interrupt_handle.clone() } + fn 
get_and_clear_dirty_pages(&mut self) -> Result<Vec<u64>> { + let mut page_indices = vec![]; + let mut current_page = 0; + // Iterate over all memory regions and get the dirty pages for each region, ignoring guard pages, which cannot be dirty + for (i, mem_region) in self.mem_regions.iter().enumerate() { + let num_pages = mem_region.guest_region.len() / PAGE_SIZE_USIZE; + let bitmap = match mem_region.flags { + MemoryRegionFlags::READ => { + // read-only page. It can never be dirty so return zero dirty pages. + new_page_bitmap(mem_region.guest_region.len(), false)? + } + _ => { + if mem_region.region_type == MemoryRegionType::GuardPage { + // Trying to get dirty pages for a guard page region results in a VMMSysError(2) + new_page_bitmap(mem_region.guest_region.len(), false)? + } else { + // Get the dirty bitmap for the memory region + self.vm_fd + .get_dirty_log(i as u32, mem_region.guest_region.len())? + } + } + }; + for page_idx in bit_index_iterator(&bitmap) { + page_indices.push(current_page + page_idx); + } + current_page += num_pages; + } + + // convert the vec of page indices to a vec of blocks (bitmap) + let mut res = new_page_bitmap(current_page * PAGE_SIZE_USIZE, false)?; + for page_idx in page_indices { + let block_idx = page_idx / PAGES_IN_BLOCK; + let bit_idx = page_idx % PAGES_IN_BLOCK; + res[block_idx] |= 1 << bit_idx; + } + + Ok(res) + } + #[cfg(crashdump)] fn crashdump_context(&self) -> Result> { if self.rt_cfg.guest_core_dump { diff --git a/src/hyperlight_host/src/hypervisor/mod.rs b/src/hyperlight_host/src/hypervisor/mod.rs index ecf6acbc5..9fc276303 100644 --- a/src/hyperlight_host/src/hypervisor/mod.rs +++ b/src/hyperlight_host/src/hypervisor/mod.rs @@ -196,6 +196,11 @@ pub(crate) trait Hypervisor: Debug + Sync + Send { None } + /// Get dirty pages as a bitmap (Vec<u64>). + /// Each bit in a u64 represents a page. + /// This also clears the bitflags, marking the pages as non-dirty. + fn get_and_clear_dirty_pages(&mut self) -> Result<Vec<u64>>; + /// Get InterruptHandle to underlying VM fn interrupt_handle(&self) -> Arc<dyn InterruptHandle>; @@ -506,6 +511,7 @@ pub(crate) mod tests { use super::handlers::{MemAccessHandler, OutBHandler}; #[cfg(gdb)] use crate::hypervisor::DbgMemAccessHandlerCaller; + use crate::mem::dirty_page_tracking::DirtyPageTracking; use crate::mem::ptr::RawPtr; use crate::sandbox::uninitialized::GuestBinary; #[cfg(any(crashdump, gdb))] @@ -557,7 +563,10 @@ pub(crate) mod tests { let rt_cfg: SandboxRuntimeConfig = Default::default(); let sandbox = UninitializedSandbox::new(GuestBinary::FilePath(filename.clone()), Some(config))?; + let tracker = sandbox.tracker.unwrap(); let (_hshm, mut gshm) = sandbox.mgr.build(); + // We need to undo the mprotect(readonly) before mapping memory into the VM, and that is done by getting the dirty pages + let _ = tracker.get_dirty_pages()?; let mut vm = set_up_hypervisor_partition( &mut gshm, &config, diff --git a/src/hyperlight_host/src/mem/bitmap.rs b/src/hyperlight_host/src/mem/bitmap.rs new file mode 100644 index 000000000..733a4de68 --- /dev/null +++ b/src/hyperlight_host/src/mem/bitmap.rs @@ -0,0 +1,193 @@ +/* +Copyright 2025 The Hyperlight Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and +limitations under the License. +*/ +use std::cmp::Ordering; + +use hyperlight_common::mem::{PAGE_SIZE_USIZE, PAGES_IN_BLOCK}; + +use crate::{Result, log_then_return}; + +// Contains various helper functions for dealing with bitmaps. + +/// Returns a new bitmap of pages. If `init_dirty` is true, all pages are marked as dirty, otherwise all pages are clean. +/// Will return an error if given size is 0. +pub fn new_page_bitmap(size_in_bytes: usize, init_dirty: bool) -> Result> { + if size_in_bytes == 0 { + log_then_return!("Tried to create a bitmap with size 0."); + } + let num_pages = size_in_bytes.div_ceil(PAGE_SIZE_USIZE); + let num_blocks = num_pages.div_ceil(PAGES_IN_BLOCK); + match init_dirty { + false => Ok(vec![0; num_blocks]), + true => { + let mut bitmap = vec![!0u64; num_blocks]; // all pages are dirty + let num_unused_bits = num_blocks * PAGES_IN_BLOCK - num_pages; + // set the unused bits to 0, could cause problems otherwise + #[allow(clippy::unwrap_used)] + let last_block = bitmap.last_mut().unwrap(); // unwrap is safe since size_in_bytes>0 + *last_block >>= num_unused_bits; + Ok(bitmap) + } + } +} + +/// Returns the union (bitwise OR) of two bitmaps. The resulting bitmap will have the same length +/// as the longer of the two input bitmaps. +pub(crate) fn bitmap_union(bitmap: &[u64], other_bitmap: &[u64]) -> Vec { + let min_len = bitmap.len().min(other_bitmap.len()); + let max_len = bitmap.len().max(other_bitmap.len()); + + let mut result = vec![0; max_len]; + + for i in 0..min_len { + result[i] = bitmap[i] | other_bitmap[i]; + } + + match bitmap.len().cmp(&other_bitmap.len()) { + Ordering::Greater => { + result[min_len..].copy_from_slice(&bitmap[min_len..]); + } + Ordering::Less => { + result[min_len..].copy_from_slice(&other_bitmap[min_len..]); + } + Ordering::Equal => {} + } + + result +} + +// Used as a helper struct to implement an iterator on. +struct SetBitIndices<'a> { + bitmap: &'a [u64], + block_index: usize, // one block is 1 u64, which is 64 pages + current: u64, // the current block we are iterating over, or 0 if first iteration +} + +/// Iterates over the zero-based indices of the set bits in the given bitmap. 
+pub(crate) fn bit_index_iterator(bitmap: &[u64]) -> impl Iterator + '_ { + SetBitIndices { + bitmap, + block_index: 0, + current: 0, + } +} + +impl Iterator for SetBitIndices<'_> { + type Item = usize; + + fn next(&mut self) -> Option { + while self.current == 0 { + // will always enter this on first iteration because current is initialized to 0 + if self.block_index >= self.bitmap.len() { + // no more blocks to iterate over + return None; + } + self.current = self.bitmap[self.block_index]; + self.block_index += 1; + } + let trailing_zeros = self.current.trailing_zeros(); + self.current &= self.current - 1; // Clear the least significant set bit + Some((self.block_index - 1) * 64 + trailing_zeros as usize) // block_index guaranteed to be > 0 at this point + } +} + +#[cfg(test)] +mod tests { + use hyperlight_common::mem::PAGE_SIZE_USIZE; + + use crate::Result; + use crate::mem::bitmap::{bit_index_iterator, bitmap_union, new_page_bitmap}; + + #[test] + fn new_page_bitmap_test() -> Result<()> { + let bitmap = new_page_bitmap(1, false)?; + assert_eq!(bitmap.len(), 1); + assert_eq!(bitmap[0], 0); + + let bitmap = new_page_bitmap(1, true)?; + assert_eq!(bitmap.len(), 1); + assert_eq!(bitmap[0], 1); + + let bitmap = new_page_bitmap(32 * PAGE_SIZE_USIZE, false)?; + assert_eq!(bitmap.len(), 1); + assert_eq!(bitmap[0], 0); + + let bitmap = new_page_bitmap(32 * PAGE_SIZE_USIZE, true)?; + assert_eq!(bitmap.len(), 1); + assert_eq!(bitmap[0], 0x0000_0000_FFFF_FFFF); + Ok(()) + } + + #[test] + fn page_iterator() { + let data = vec![0b1000010100, 0b01, 0b100000000000000011]; + let mut iter = bit_index_iterator(&data); + assert_eq!(iter.next(), Some(2)); + assert_eq!(iter.next(), Some(4)); + assert_eq!(iter.next(), Some(9)); + assert_eq!(iter.next(), Some(64)); + assert_eq!(iter.next(), Some(128)); + assert_eq!(iter.next(), Some(129)); + assert_eq!(iter.next(), Some(145)); + assert_eq!(iter.next(), None); + + let data_2 = vec![0, 0, 0]; + let mut iter_2 = bit_index_iterator(&data_2); + assert_eq!(iter_2.next(), None); + + let data_3 = vec![0, 0, 0b1, 1 << 63]; + let mut iter_3 = bit_index_iterator(&data_3); + assert_eq!(iter_3.next(), Some(128)); + assert_eq!(iter_3.next(), Some(255)); + assert_eq!(iter_3.next(), None); + + let data_4 = vec![]; + let mut iter_4 = bit_index_iterator(&data_4); + assert_eq!(iter_4.next(), None); + } + + #[test] + fn union() -> Result<()> { + let a = 0b1000010100; + let b = 0b01; + let c = 0b100000000000000011; + let d = 0b101010100000011000000011; + let e = 0b000000000000001000000000000000000000; + let f = 0b100000000000000001010000000001010100000000000; + let bitmap = vec![a, b, c]; + let other_bitmap = vec![d, e, f]; + let union = bitmap_union(&bitmap, &other_bitmap); + assert_eq!(union, vec![a | d, b | e, c | f]); + + // different length + let union = bitmap_union(&[a], &[d, e, f]); + assert_eq!(union, vec![a | d, e, f]); + + let union = bitmap_union(&[a, b, c], &[d]); + assert_eq!(union, vec![a | d, b, c]); + + let union = bitmap_union(&[], &[d, e]); + assert_eq!(union, vec![d, e]); + + let union = bitmap_union(&[a, b, c], &[]); + assert_eq!(union, vec![a, b, c]); + + let union = bitmap_union(&[], &[]); + let empty: Vec = vec![]; + assert_eq!(union, empty); + + Ok(()) + } +} diff --git a/src/hyperlight_host/src/mem/dirty_page_tracking.rs b/src/hyperlight_host/src/mem/dirty_page_tracking.rs new file mode 100644 index 000000000..bcc65848a --- /dev/null +++ b/src/hyperlight_host/src/mem/dirty_page_tracking.rs @@ -0,0 +1,49 @@ +/* +Copyright 2025 The Hyperlight Authors. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +use tracing::{Span, instrument}; + +#[cfg(target_os = "linux")] +pub use super::linux_dirty_page_tracker::LinuxDirtyPageTracker as PlatformDirtyPageTracker; +use super::shared_mem::SharedMemory; +#[cfg(target_os = "windows")] +pub use super::windows_dirty_page_tracker::WindowsDirtyPageTracker as PlatformDirtyPageTracker; +use crate::Result; + +/// Trait defining the interface for dirty page tracking implementations +pub trait DirtyPageTracking { + fn get_dirty_pages(self) -> Result>; +} + +/// Cross-platform dirty page tracker that delegates to platform-specific implementations +pub struct DirtyPageTracker { + inner: PlatformDirtyPageTracker, +} + +impl DirtyPageTracker { + /// Create a new dirty page tracker for the given shared memory + #[instrument(skip_all, parent = Span::current(), level = "Trace")] + pub fn new(shared_memory: &T) -> Result { + let inner = PlatformDirtyPageTracker::new(shared_memory)?; + Ok(Self { inner }) + } +} + +impl DirtyPageTracking for DirtyPageTracker { + fn get_dirty_pages(self) -> Result> { + self.inner.get_dirty_pages() + } +} diff --git a/src/hyperlight_host/src/mem/layout.rs b/src/hyperlight_host/src/mem/layout.rs index 04edc9bcc..aafcffc52 100644 --- a/src/hyperlight_host/src/mem/layout.rs +++ b/src/hyperlight_host/src/mem/layout.rs @@ -246,7 +246,7 @@ impl SandboxMemoryLayout { pub(crate) const BASE_ADDRESS: usize = 0x0; // the offset into a sandbox's input/output buffer where the stack starts - const STACK_POINTER_SIZE_BYTES: u64 = 8; + pub(crate) const STACK_POINTER_SIZE_BYTES: u64 = 8; /// Create a new `SandboxMemoryLayout` with the given /// `SandboxConfiguration`, code size and stack/heap size. @@ -397,7 +397,7 @@ impl SandboxMemoryLayout { /// Get the offset in guest memory to the output data pointer. #[instrument(skip_all, parent = Span::current(), level= "Trace")] - fn get_output_data_pointer_offset(&self) -> usize { + pub(super) fn get_output_data_pointer_offset(&self) -> usize { // This field is immediately after the output data size field, // which is a `u64`. self.get_output_data_size_offset() + size_of::() @@ -429,7 +429,7 @@ impl SandboxMemoryLayout { /// Get the offset in guest memory to the input data pointer. #[instrument(skip_all, parent = Span::current(), level= "Trace")] - fn get_input_data_pointer_offset(&self) -> usize { + pub(super) fn get_input_data_pointer_offset(&self) -> usize { // The input data pointer is immediately after the input // data size field in the input data `GuestMemoryRegion` struct which is a `u64`. self.get_input_data_size_offset() + size_of::() diff --git a/src/hyperlight_host/src/mem/linux_dirty_page_tracker.rs b/src/hyperlight_host/src/mem/linux_dirty_page_tracker.rs new file mode 100644 index 000000000..5336992e4 --- /dev/null +++ b/src/hyperlight_host/src/mem/linux_dirty_page_tracker.rs @@ -0,0 +1,1291 @@ +/* +Copyright 2025 The Hyperlight Authors. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +use std::ptr; +use std::sync::atomic::{AtomicBool, AtomicPtr, AtomicUsize, Ordering}; +use std::sync::{Arc, OnceLock}; + +use hyperlight_common::mem::PAGE_SIZE_USIZE; +use libc::{PROT_READ, PROT_WRITE, mprotect}; +use lockfree::map::Map; +use log::error; + +use crate::mem::shared_mem::{HostMapping, SharedMemory}; +use crate::{Result, new_error}; + +// Tracker metadata stored in global lock-free storage +struct TrackerData { + pid: u32, + base_addr: usize, + size: usize, + num_pages: usize, + dirty_pages: Vec, +} + +// Global lock-free collection to store tracker data for signal handler to access + +static TRACKERS: OnceLock> = OnceLock::new(); + +// Helper function to get or initialize the global trackers map +// lockfree::Map is truly lock-free and safe for signal handlers +fn get_trackers() -> &'static Map { + TRACKERS.get_or_init(Map::new) +} + +/// Global tracker ID counter +static NEXT_TRACKER_ID: AtomicUsize = AtomicUsize::new(1); + +/// Original SIGSEGV handler to chain to (stored atomically for async signal safety) +static ORIGINAL_SIGSEGV_HANDLER: AtomicPtr = AtomicPtr::new(ptr::null_mut()); + +/// Whether our SIGSEGV handler is installed +static HANDLER_INSTALLED: AtomicBool = AtomicBool::new(false); + +/// Dirty page tracker for Linux +/// This tracks which pages have been written for a memory region once new has been called +/// It marks pages as RO and then uses SIGSEGV to detect writes to pages, then updates the page to RW and notes the page index as dirty by writing details to global lock-free storage +/// +/// A user calls get_dirty_pages to get a list of dirty pages to get details of the pages that were written to since the tracker was created +/// +/// Once a user has called get_dirty_pages, this tracker is destroyed and will not track changes any longer +#[derive(Debug)] +pub struct LinuxDirtyPageTracker { + /// Unique ID for this tracker + id: usize, + /// Base address of the memory region being tracked + base_addr: usize, + /// Size of the memory region in bytes + size: usize, + /// Keep a reference to the HostMapping to ensure memory lifetime + _mapping: Arc, +} + +// DirtyPageTracker should be Send because: +// 1. The Arc ensures the memory stays valid +// 2. The tracker handles synchronization properly +// 3. 
This is needed for threaded sandbox initialization +unsafe impl Send for LinuxDirtyPageTracker {} + +impl LinuxDirtyPageTracker { + /// Create a new dirty page tracker for the given shared memory + pub(super) fn new(shared_memory: &T) -> Result { + let mapping = shared_memory.region_arc(); + let base_addr = shared_memory.base_addr(); + let size = shared_memory.mem_size(); + + if size == 0 { + return Err(new_error!("Cannot track empty memory region")); + } + + if base_addr % PAGE_SIZE_USIZE != 0 { + return Err(new_error!("Base address must be page-aligned")); + } + + // Get the current process ID + let current_pid = std::process::id(); + + // Check that there is not already a tracker that includes this address range + // within the same process (virtual addresses are only unique per process) + for guard in get_trackers().iter() { + let tracker_data = guard.val(); + + // Only check for overlaps within the same process + if tracker_data.pid == current_pid { + let existing_start = tracker_data.base_addr; + let existing_end = tracker_data.base_addr + tracker_data.size; + let new_start = base_addr; + let new_end = base_addr + size; + + // Check for overlap: two ranges [a,b) and [c,d) overlap if max(a,c) < min(b,d) + // Equivalently: they DON'T overlap if b <= c || d <= a + // So they DO overlap if !(b <= c || d <= a) which is (b > c && d > a) + if new_end > existing_start && existing_end > new_start { + return Err(new_error!( + "Address range [{:#x}, {:#x}) overlaps with existing tracker [{:#x}, {:#x}) in process {}", + new_start, + new_end, + existing_start, + existing_end, + current_pid + )); + } + } + } + + let num_pages = size.div_ceil(PAGE_SIZE_USIZE); + let id = NEXT_TRACKER_ID.fetch_add(1, Ordering::Relaxed); + + // Create atomic array for dirty page tracking + let dirty_pages: Vec = (0..num_pages).map(|_| AtomicBool::new(false)).collect(); + + // Create tracker data + let tracker_data = TrackerData { + pid: current_pid, + base_addr, + size, + num_pages, + dirty_pages, + }; + + // Install global SIGSEGV handler if not already installed + Self::ensure_sigsegv_handler_installed()?; + + // Write protect the memory region to make it read-only so we get SIGSEGV on writes + let result = unsafe { mprotect(base_addr as *mut libc::c_void, size, PROT_READ) }; + + if result != 0 { + return Err(new_error!( + "Failed to write-protect memory for dirty tracking: {}", + std::io::Error::last_os_error() + )); + } + + get_trackers().insert(id, tracker_data); + + Ok(Self { + id, + base_addr, + size, + _mapping: mapping, + }) + } + + /// Get all dirty page indices for this tracker. + /// NOTE: This is not a bitmap, but a vector of indices where each index corresponds to a page that has been written to. 
+ pub(super) fn get_dirty_pages(self) -> Result> { + let res: Vec = if let Some(tracker_data) = get_trackers().get(&self.id) { + let mut dirty_pages = Vec::new(); + let tracker_data = tracker_data.val(); + for (idx, dirty) in tracker_data.dirty_pages.iter().enumerate() { + if dirty.load(Ordering::Acquire) { + dirty_pages.push(idx); + } + } + dirty_pages + } else { + return Err(new_error!( + "Tried to get dirty pages from tracker, but no tracker data found" + )); + }; + + // explicit to document intent + drop(self); + + Ok(res) + } + + #[cfg(test)] + /// Check if a memory address falls within this tracker's region + fn contains_address(&self, addr: usize) -> bool { + addr >= self.base_addr && addr < self.base_addr + self.size + } + + /// Install global SIGSEGV handler if not already installed + fn ensure_sigsegv_handler_installed() -> Result<()> { + // Use compare_exchange to ensure only one thread does the installation + match HANDLER_INSTALLED.compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire) { + Ok(_) => { + // We won the race - we're responsible for installation + + // Get the current handler before installing ours + let mut original = Box::new(unsafe { std::mem::zeroed::() }); + + unsafe { + let result = libc::sigaction( + libc::SIGSEGV, + std::ptr::null(), + original.as_mut() as *mut libc::sigaction, + ); + + if result != 0 { + // Reset the flag on error + HANDLER_INSTALLED.store(false, Ordering::Release); + return Err(new_error!( + "Failed to get original SIGSEGV handler: {}", + std::io::Error::last_os_error() + )); + } + } + + // Install our handler + if let Err(e) = vmm_sys_util::signal::register_signal_handler( + libc::SIGSEGV, + Self::sigsegv_handler, + ) { + // Reset the flag on error + HANDLER_INSTALLED.store(false, Ordering::Release); + return Err(new_error!("Failed to register SIGSEGV handler: {}", e)); + } + + // Store original handler pointer atomically + let original_ptr = Box::into_raw(original); + ORIGINAL_SIGSEGV_HANDLER.store(original_ptr, Ordering::Release); + + Ok(()) + } + Err(_) => { + // Another thread already installed it, we're done + Ok(()) + } + } + } + + /// MINIMAL async signal safe SIGSEGV handler for dirty page tracking + /// This handler uses only async signal safe operations: + /// - Atomic loads/stores + /// - mprotect (async signal safe) + /// - Simple pointer arithmetic + /// - global lock-free storage (lockfree::Map) + /// - `getpid()` to check process ownership + extern "C" fn sigsegv_handler( + signal: libc::c_int, + info: *mut libc::siginfo_t, + context: *mut libc::c_void, + ) { + unsafe { + if signal != libc::SIGSEGV || info.is_null() { + Self::call_original_handler(signal, info, context); + return; + } + + let fault_addr = (*info).si_addr() as usize; + + // Check all trackers in global lock-free storage + // lockfree::Map::iter() is guaranteed to be async-signal-safe + let mut handled = false; + for guard in get_trackers().iter() { + let tracker_data = guard.val(); + + // Only handle faults for trackers in the current process + // We compare the stored PID with the current process PID + // getpid() is async-signal-safe, but we can avoid the call by checking + // if the fault address is within this tracker's range first + if fault_addr < tracker_data.base_addr + || fault_addr >= tracker_data.base_addr + tracker_data.size + { + continue; // Fault not in this tracker's range + } + + // Now verify this tracker belongs to the current process + let current_pid = libc::getpid() as u32; + if tracker_data.pid != current_pid { + continue; 
+ } + + // We know the fault is in this tracker's range and it's our process + // Calculate page index + let page_offset = fault_addr - tracker_data.base_addr; + let page_idx = page_offset / PAGE_SIZE_USIZE; + + if page_idx < tracker_data.num_pages { + // Mark page dirty atomically (async signal safe) + tracker_data.dirty_pages[page_idx].store(true, Ordering::Relaxed); + + // Make page writable (mprotect is async signal safe) + let page_addr = tracker_data.base_addr + (page_idx * PAGE_SIZE_USIZE); + let result = mprotect( + page_addr as *mut libc::c_void, + PAGE_SIZE_USIZE, + PROT_READ | PROT_WRITE, + ); + + handled = result == 0; + break; // Found the tracker, stop searching + } + } + + // If not handled by any of our trackers, chain to original handler + if !handled { + Self::call_original_handler(signal, info, context); + } + } + } + + /// Call the original SIGSEGV handler if available (async signal safe) + fn call_original_handler( + signal: libc::c_int, + info: *mut libc::siginfo_t, + context: *mut libc::c_void, + ) { + unsafe { + let handler_ptr = ORIGINAL_SIGSEGV_HANDLER.load(Ordering::Acquire); + if !handler_ptr.is_null() { + let original = &*handler_ptr; + if original.sa_sigaction != 0 { + let handler_fn: extern "C" fn( + libc::c_int, + *mut libc::siginfo_t, + *mut libc::c_void, + ) = std::mem::transmute(original.sa_sigaction); + handler_fn(signal, info, context); + } + } + } + } +} + +impl Drop for LinuxDirtyPageTracker { + fn drop(&mut self) { + // Remove this tracker's metadata from global lock-free storage + if get_trackers().remove(&self.id).is_none() { + error!("Tracker {} not found in global storage", self.id); + } + + // Restore memory protection + unsafe { + mprotect( + self.base_addr as *mut libc::c_void, + self.size, + PROT_READ | PROT_WRITE, + ); + } + } +} + +#[cfg(test)] +mod tests { + use std::collections::HashSet; + use std::ptr::null_mut; + use std::sync::{Arc, Barrier}; + use std::thread; + + use libc::{MAP_ANONYMOUS, MAP_FAILED, MAP_PRIVATE, PROT_READ, PROT_WRITE, mmap, munmap}; + use rand::{Rng, rng}; + + use super::*; + use crate::mem::shared_mem::{HostMapping, SharedMemory}; + + const PAGE_SIZE: usize = 4096; + + /// Helper function to create a tracker from raw memory parameters + fn create_test_tracker(base_addr: usize, size: usize) -> Result { + let test_memory = TestSharedMemory::new(base_addr, size); + LinuxDirtyPageTracker::new(&test_memory) + } + + /// Test implementation of SharedMemory for raw memory regions + struct TestSharedMemory { + mapping: Arc, + base_addr: usize, + size: usize, + } + + impl TestSharedMemory { + fn new(base_addr: usize, size: usize) -> Self { + // Create a real ExclusiveSharedMemory and extract its mapping + // This ensures we have a proper HostMapping for testing + let total_size = size + 2 * PAGE_SIZE_USIZE; + let exclusive = crate::mem::shared_mem::ExclusiveSharedMemory::new(total_size).unwrap(); + let mapping = exclusive.region_arc(); + + Self { + mapping, + base_addr, + size, + } + } + } + + impl SharedMemory for TestSharedMemory { + fn region(&self) -> &HostMapping { + &self.mapping + } + + fn region_arc(&self) -> Arc { + Arc::clone(&self.mapping) + } + + fn base_addr(&self) -> usize { + self.base_addr + } + + fn mem_size(&self) -> usize { + self.size + } + + fn with_exclusivity< + T, + F: FnOnce(&mut crate::mem::shared_mem::ExclusiveSharedMemory) -> T, + >( + &mut self, + _f: F, + ) -> crate::Result { + unimplemented!("TestSharedMemory doesn't support with_exclusivity") + } + } + + /// Helper to create page-aligned 
memory for testing + /// Returns (pointer, size) tuple + fn create_aligned_memory(size: usize) -> (*mut u8, usize) { + let addr = unsafe { + mmap( + null_mut(), + size, + PROT_READ | PROT_WRITE, + MAP_ANONYMOUS | MAP_PRIVATE, + -1, + 0, + ) + }; + + if addr == MAP_FAILED { + panic!("Failed to allocate aligned memory with mmap"); + } + + (addr as *mut u8, size) + } + + /// Helper to clean up mmap'd memory + unsafe fn free_aligned_memory(ptr: *mut u8, size: usize) { + if unsafe { munmap(ptr as *mut libc::c_void, size) } != 0 { + eprintln!("Warning: Failed to unmap memory"); + } + } + + #[test] + fn test_tracker_creation() { + let (memory_ptr, memory_size) = create_aligned_memory(PAGE_SIZE * 4); + let addr = memory_ptr as usize; + + let test_memory = TestSharedMemory::new(addr, memory_size); + let tracker = LinuxDirtyPageTracker::new(&test_memory); + println!("Tracker created: {:?}", tracker); + assert!(tracker.is_ok()); + let tracker = tracker.unwrap(); + + // Explicitly drop tracker before freeing memory + drop(tracker); + + unsafe { + free_aligned_memory(memory_ptr, memory_size); + } + } + + #[test] + fn test_zero_size_memory_fails() { + let addr = 0x1000; // Page-aligned address + let test_memory = TestSharedMemory::new(addr, 0); + let result = LinuxDirtyPageTracker::new(&test_memory); + assert!(result.is_err()); + } + + #[test] + fn test_unaligned_address_fails() { + let unaligned_addr = 0x1001; // Not page-aligned + let size = PAGE_SIZE; + let test_memory = TestSharedMemory::new(unaligned_addr, size); + let result = LinuxDirtyPageTracker::new(&test_memory); + assert!(result.is_err()); + } + + #[test] + fn test_overlapping_trackers_all_fail() { + let (memory_ptr, memory_size) = create_aligned_memory(PAGE_SIZE * 20); // Large enough for all test cases + let base_memory_addr = memory_ptr as usize; + + // Define test cases for different overlap scenarios + // Each test case: (existing_offset, existing_size, new_offset, new_size, description) + let test_cases = vec![ + // Case 1: New range completely overlaps existing (new contains existing) + ( + PAGE_SIZE * 4, + PAGE_SIZE * 4, + PAGE_SIZE * 2, + PAGE_SIZE * 8, + "new contains existing", + ), + // Case 2: New range completely contained by existing (existing contains new) + ( + PAGE_SIZE * 2, + PAGE_SIZE * 8, + PAGE_SIZE * 4, + PAGE_SIZE * 4, + "existing contains new", + ), + // Case 3: New range overlaps start of existing + ( + PAGE_SIZE * 4, + PAGE_SIZE * 4, + PAGE_SIZE * 2, + PAGE_SIZE * 4, + "new overlaps start of existing", + ), + // Case 4: New range overlaps end of existing + ( + PAGE_SIZE * 2, + PAGE_SIZE * 4, + PAGE_SIZE * 4, + PAGE_SIZE * 4, + "new overlaps end of existing", + ), + // Case 5: New range exactly matches existing + ( + PAGE_SIZE * 4, + PAGE_SIZE * 4, + PAGE_SIZE * 4, + PAGE_SIZE * 4, + "new exactly matches existing", + ), + // Case 6: New range starts at same address but different size + ( + PAGE_SIZE * 4, + PAGE_SIZE * 4, + PAGE_SIZE * 4, + PAGE_SIZE * 2, + "new starts same, smaller size", + ), + ( + PAGE_SIZE * 4, + PAGE_SIZE * 2, + PAGE_SIZE * 4, + PAGE_SIZE * 4, + "new starts same, larger size", + ), + // Case 7: New range ends at same address but different start + ( + PAGE_SIZE * 4, + PAGE_SIZE * 4, + PAGE_SIZE * 6, + PAGE_SIZE * 2, + "new ends same, different start", + ), + ( + PAGE_SIZE * 6, + PAGE_SIZE * 2, + PAGE_SIZE * 4, + PAGE_SIZE * 4, + "new ends same, earlier start", + ), + // Case 8: Single page overlaps + ( + PAGE_SIZE * 4, + PAGE_SIZE, + PAGE_SIZE * 4, + PAGE_SIZE, + "single page exact match", + ), + 
( + PAGE_SIZE * 4, + PAGE_SIZE * 2, + PAGE_SIZE * 5, + PAGE_SIZE, + "single page within existing", + ), + // Case 9: Multi-page overlaps + ( + PAGE_SIZE * 5, + PAGE_SIZE * 3, + PAGE_SIZE * 3, + PAGE_SIZE * 4, + "multi-page partial overlap start", + ), + ( + PAGE_SIZE * 3, + PAGE_SIZE * 4, + PAGE_SIZE * 5, + PAGE_SIZE * 3, + "multi-page partial overlap end", + ), + ]; + + for (i, (existing_offset, existing_size, new_offset, new_size, description)) in + test_cases.iter().enumerate() + { + println!("Test case {}: {}", i + 1, description); + + let existing_addr = base_memory_addr + existing_offset; + let new_addr = base_memory_addr + new_offset; + + println!( + " Existing: [{:#x}, {:#x}) (size: {})", + existing_addr, + existing_addr + existing_size, + existing_size + ); + println!( + " New: [{:#x}, {:#x}) (size: {})", + new_addr, + new_addr + new_size, + new_size + ); + + // Create the first tracker + let test_memory1 = TestSharedMemory::new(existing_addr, *existing_size); + let tracker1 = LinuxDirtyPageTracker::new(&test_memory1); + assert!( + tracker1.is_ok(), + "Failed to create first tracker for test case: {}", + description + ); + let tracker1 = tracker1.unwrap(); + + // Try to create overlapping tracker - this should fail + let test_memory2 = TestSharedMemory::new(new_addr, *new_size); + let tracker2_result = LinuxDirtyPageTracker::new(&test_memory2); + assert!( + tracker2_result.is_err(), + "Expected overlapping tracker to fail for test case: {}\n Existing: [{:#x}, {:#x})\n New: [{:#x}, {:#x})", + description, + existing_addr, + existing_addr + existing_size, + new_addr, + new_addr + new_size + ); + + println!(" ✓ Correctly rejected overlap"); + + // Clean up by dropping the tracker + drop(tracker1); + println!(); + } + + // Test cases that should NOT overlap (adjacent ranges) + let non_overlapping_cases = [ + // Case 1: Adjacent ranges (end of first == start of second) + ( + PAGE_SIZE * 4, + PAGE_SIZE * 4, + PAGE_SIZE * 8, + PAGE_SIZE * 4, + "adjacent ranges (end to start)", + ), + // Case 2: Adjacent ranges (end of second == start of first) + ( + PAGE_SIZE * 8, + PAGE_SIZE * 4, + PAGE_SIZE * 4, + PAGE_SIZE * 4, + "adjacent ranges (start to end)", + ), + // Case 3: Completely separate ranges + ( + PAGE_SIZE * 2, + PAGE_SIZE * 2, + PAGE_SIZE * 6, + PAGE_SIZE * 2, + "completely separate ranges", + ), + ( + PAGE_SIZE * 10, + PAGE_SIZE * 2, + PAGE_SIZE * 2, + PAGE_SIZE * 2, + "completely separate ranges (reversed)", + ), + ]; + + println!("Testing non-overlapping cases (these should succeed):"); + for (i, (existing_offset, existing_size, new_offset, new_size, description)) in + non_overlapping_cases.iter().enumerate() + { + println!("Non-overlap test case {}: {}", i + 1, description); + + let existing_addr = base_memory_addr + existing_offset; + let new_addr = base_memory_addr + new_offset; + + println!( + " Existing: [{:#x}, {:#x}) (size: {})", + existing_addr, + existing_addr + existing_size, + existing_size + ); + println!( + " New: [{:#x}, {:#x}) (size: {})", + new_addr, + new_addr + new_size, + new_size + ); + + // Create the first tracker + let test_memory1 = TestSharedMemory::new(existing_addr, *existing_size); + let tracker1 = LinuxDirtyPageTracker::new(&test_memory1); + assert!( + tracker1.is_ok(), + "Failed to create first tracker for non-overlap test: {}", + description + ); + let tracker1 = tracker1.unwrap(); + + // Try to create non-overlapping tracker - this should succeed + let test_memory2 = TestSharedMemory::new(new_addr, *new_size); + let tracker2_result = 
LinuxDirtyPageTracker::new(&test_memory2); + assert!( + tracker2_result.is_ok(), + "Expected non-overlapping tracker to succeed for test case: {}\n Existing: [{:#x}, {:#x})\n New: [{:#x}, {:#x})", + description, + existing_addr, + existing_addr + existing_size, + new_addr, + new_addr + new_size + ); + + let tracker2 = tracker2_result.unwrap(); + println!(" ✓ Correctly allowed non-overlapping ranges"); + + // Clean up + drop(tracker1); + drop(tracker2); + println!(); + } + + unsafe { + free_aligned_memory(memory_ptr, memory_size); + } + } + + #[test] + fn test_three_way_overlap_detection() { + let (memory_ptr, memory_size) = create_aligned_memory(PAGE_SIZE * 15); + let base_addr = memory_ptr as usize; + + // Create two non-overlapping trackers first + let tracker1 = create_test_tracker(base_addr + PAGE_SIZE * 2, PAGE_SIZE * 3).unwrap(); + let tracker2 = create_test_tracker(base_addr + PAGE_SIZE * 8, PAGE_SIZE * 3).unwrap(); + + // Try to create a tracker that overlaps with tracker1 + let overlap_with_1 = create_test_tracker(base_addr + PAGE_SIZE * 3, PAGE_SIZE * 3); + assert!( + overlap_with_1.is_err(), + "Should reject overlap with first tracker" + ); + + // Try to create a tracker that overlaps with tracker2 + let overlap_with_2 = create_test_tracker(base_addr + PAGE_SIZE * 7, PAGE_SIZE * 3); + assert!( + overlap_with_2.is_err(), + "Should reject overlap with second tracker" + ); + + // Try to create a tracker that spans both (overlaps with both) + let overlap_with_both = create_test_tracker(base_addr + PAGE_SIZE * 4, PAGE_SIZE * 6); + assert!( + overlap_with_both.is_err(), + "Should reject overlap with both trackers" + ); + + // Create a tracker that doesn't overlap with either (should succeed) + let no_overlap = create_test_tracker(base_addr + PAGE_SIZE * 12, PAGE_SIZE * 2); + assert!(no_overlap.is_ok(), "Should allow non-overlapping tracker"); + + // Explicitly drop all trackers before freeing memory + drop(tracker1); + drop(tracker2); + drop(no_overlap); + + unsafe { + free_aligned_memory(memory_ptr, memory_size); + } + } + + #[test] + fn test_get_dirty_pages_initially_empty() { + let (memory_ptr, memory_size) = create_aligned_memory(PAGE_SIZE * 4); + let addr = memory_ptr as usize; + + let tracker = create_test_tracker(addr, memory_size).unwrap(); + let dirty_pages = tracker.get_dirty_pages().unwrap(); + assert!(dirty_pages.is_empty()); + + // tracker is already dropped by get_dirty_pages() call above + + unsafe { + free_aligned_memory(memory_ptr, memory_size); + } + } + + #[test] + fn test_random_page_dirtying() { + let (memory_ptr, memory_size) = create_aligned_memory(PAGE_SIZE * 10); + let addr = memory_ptr as usize; + + let tracker = create_test_tracker(addr, memory_size).unwrap(); + + // Simulate random page access by directly writing to memory + // This should trigger the SIGSEGV handler and mark pages as dirty + + // generate 5 random page indices to dirty + let mut pages_to_dirty: HashSet = HashSet::new(); + while pages_to_dirty.len() < 5 { + let page_idx = rand::random::() % 10; // 0 to 9 + pages_to_dirty.insert(page_idx as usize); + } + + for &page_idx in &pages_to_dirty { + let page_offset = page_idx * PAGE_SIZE; + if page_offset < memory_size { + // Write to the memory to trigger dirty tracking + unsafe { + let write_addr = (addr + page_offset + 100) as *mut u8; + std::ptr::write_volatile(write_addr, 42); + } + } + } + + let dirty_pages = tracker.get_dirty_pages().unwrap(); + + println!("Dirty Pages expected: {:?}", pages_to_dirty); + println!("Dirty pages found: {:?}", 
dirty_pages); + + // check that the dirty pages only contain the indices we wrote to + for &page_idx in &pages_to_dirty { + assert!( + dirty_pages.contains(&page_idx), + "Page {} should be dirty", + page_idx + ); + } + // Check that no other pages are dirty + for &page_idx in &dirty_pages { + assert!( + pages_to_dirty.contains(&page_idx), + "Unexpected dirty page: {}", + page_idx + ); + } + + // tracker is already dropped by get_dirty_pages() call above + + unsafe { + free_aligned_memory(memory_ptr, memory_size); + } + } + + #[test] + fn test_multiple_trackers_different_regions() { + let (memory_ptr1, memory_size1) = create_aligned_memory(PAGE_SIZE * 4); + let (memory_ptr2, memory_size2) = create_aligned_memory(PAGE_SIZE * 4); + let addr1 = memory_ptr1 as usize; + let addr2 = memory_ptr2 as usize; + + let tracker1 = create_test_tracker(addr1, memory_size1).unwrap(); + let tracker2 = create_test_tracker(addr2, memory_size2).unwrap(); + + // Write to different memory regions + unsafe { + std::ptr::write_volatile((addr1 + 100) as *mut u8, 1); + std::ptr::write_volatile((addr2 + PAGE_SIZE + 200) as *mut u8, 2); + } + + let dirty1 = tracker1.get_dirty_pages().unwrap(); + let dirty2 = tracker2.get_dirty_pages().unwrap(); + + // Verify each tracker only reports pages that were actually written to + // Tracker1: wrote to offset 100, which is in page 0 + assert!(dirty1.contains(&0), "Tracker 1 should have page 0 dirty"); + assert_eq!(dirty1.len(), 1, "Tracker 1 should only have 1 dirty page"); + + // Tracker2: wrote to offset PAGE_SIZE + 200, which is in page 1 + assert!(dirty2.contains(&1), "Tracker 2 should have page 1 dirty"); + assert_eq!(dirty2.len(), 1, "Tracker 2 should only have 1 dirty page"); + + // Verify that each tracker's dirty pages are within expected bounds + for &page_idx in &dirty1 { + assert!( + page_idx < 4, + "Tracker 1 page index {} out of bounds", + page_idx + ); + } + for &page_idx in &dirty2 { + assert!( + page_idx < 4, + "Tracker 2 page index {} out of bounds", + page_idx + ); + } + + unsafe { + free_aligned_memory(memory_ptr1, memory_size1); + free_aligned_memory(memory_ptr2, memory_size2); + } + } + + #[test] + fn test_cleanup_on_drop() { + let (memory_ptr, memory_size) = create_aligned_memory(PAGE_SIZE * 2); + let addr = memory_ptr as usize; + + // Create tracker in a scope to test drop behavior + { + let tracker = create_test_tracker(addr, memory_size).unwrap(); + + // Write to memory to verify tracking works + unsafe { + std::ptr::write_volatile((addr + 100) as *mut u8, 42); + } + + let _ = tracker.get_dirty_pages(); + } // tracker is dropped here + + // Create a new tracker for the same memory region + // This should work without issues if data was properly cleaned up + let new_tracker = create_test_tracker(addr, memory_size); + assert!( + new_tracker.is_ok(), + "Data not properly cleaned up on tracker drop" + ); + + unsafe { + free_aligned_memory(memory_ptr, memory_size); + } + } + + #[test] + fn test_page_boundaries() { + let (memory_ptr, memory_size) = create_aligned_memory(PAGE_SIZE * 3); + let addr = memory_ptr as usize; + + let tracker = create_test_tracker(addr, memory_size).unwrap(); + + // Write to different offsets within the first page + let offsets = [0, 1, 100, 1000, PAGE_SIZE - 1]; + + for &offset in &offsets { + unsafe { + std::ptr::write_volatile((addr + offset) as *mut u8, offset as u8); + } + } + + let dirty_pages = tracker.get_dirty_pages().unwrap(); + + // All writes to the same page should result in the same page being dirty + if 
!dirty_pages.is_empty() { + // Check that page indices are within bounds + for &page_idx in &dirty_pages { + assert!(page_idx < 3, "Page index out of bounds: {}", page_idx); + } + } + + // tracker is already dropped by get_dirty_pages() call above + + unsafe { + free_aligned_memory(memory_ptr, memory_size); + } + } + + #[test] + fn test_concurrent_trackers() { + const NUM_THREADS: usize = 50; + const UPDATES_PER_THREAD: usize = 500; + const MIN_MEMORY_SIZE: usize = 1024 * 1024; // 1MB + const MAX_MEMORY_SIZE: usize = 10 * 1024 * 1024; // 10MB + + // Create barrier for synchronization + let start_writing_barrier = Arc::new(Barrier::new(NUM_THREADS)); + + let mut handles = Vec::new(); + + for thread_id in 0..NUM_THREADS { + let start_writing_barrier = Arc::clone(&start_writing_barrier); + + let handle = thread::spawn(move || { + let mut rng = rng(); + + // Generate random memory size between 1MB and 10MB + let memory_size = rng.random_range(MIN_MEMORY_SIZE..=MAX_MEMORY_SIZE); + + // Ensure memory size is page-aligned + let memory_size = (memory_size + PAGE_SIZE - 1) & !(PAGE_SIZE - 1); + let num_pages = memory_size / PAGE_SIZE; + + let (memory_ptr, _) = create_aligned_memory(memory_size); + let addr = memory_ptr as usize; + + // Create tracker (must succeed) + let tracker = + create_test_tracker(addr, memory_size).expect("Failed to create tracker"); + + // Wait for all threads to finish allocating before starting writes + start_writing_barrier.wait(); + + // Track which pages we write to + let mut pages_written = HashSet::new(); + let mut total_writes = 0; + + // Perform random memory updates + for _update_id in 0..UPDATES_PER_THREAD { + // Generate random page index + let page_idx = rng.random_range(0..num_pages); + let page_offset = page_idx * PAGE_SIZE; + + // Generate random offset within the page (avoid last byte to prevent overruns) + let within_page_offset = rng.random_range(0..(PAGE_SIZE - 1)); + let write_addr = addr + page_offset + within_page_offset; + + // Generate random value to write + let value = rng.random::(); + + // Write to memory to trigger dirty tracking + unsafe { + std::ptr::write_volatile(write_addr as *mut u8, value); + } + + // Track this page as written to (HashSet handles duplicates) + pages_written.insert(page_idx); + total_writes += 1; + } + + // Final verification: check that ALL pages we wrote to are marked as dirty + let final_dirty_pages = tracker.get_dirty_pages().unwrap(); + + // Check that every page we wrote to is marked as dirty + for &page_idx in &pages_written { + assert!( + final_dirty_pages.contains(&page_idx), + "Thread {}: Page {} was written but not marked dirty. Pages written: {:?}, Pages dirty: {:?}", + thread_id, + page_idx, + pages_written, + final_dirty_pages + ); + } + + // Verify that the number of unique dirty pages matches unique pages written + let dirty_pages_set: HashSet = final_dirty_pages.into_iter().collect(); + assert_eq!( + pages_written.len(), + dirty_pages_set.len(), + "Thread {}: Mismatch between unique pages written ({}) and unique dirty pages ({}). \ + Total writes: {}, Pages written: {:?}, Dirty pages: {:?}", + thread_id, + pages_written.len(), + dirty_pages_set.len(), + total_writes, + pages_written, + dirty_pages_set + ); + + // Verify that dirty pages don't contain extra pages we didn't write to + for &dirty_page in &dirty_pages_set { + assert!( + pages_written.contains(&dirty_page), + "Thread {}: Found dirty page {} that was not written to. 
Pages written: {:?}", + thread_id, + dirty_page, + pages_written + ); + } + + // Clean up + unsafe { + free_aligned_memory(memory_ptr, memory_size); + } + + (pages_written.len(), dirty_pages_set.len(), total_writes) + }); + + handles.push(handle); + } + + // Wait for all threads to complete and collect results + let mut total_unique_pages_written = 0; + let mut total_unique_dirty_pages = 0; + let mut total_write_operations = 0; + + for (thread_id, handle) in handles.into_iter().enumerate() { + let (unique_pages_written, unique_dirty_pages, write_operations) = handle + .join() + .unwrap_or_else(|_| panic!("Thread {} panicked", thread_id)); + + total_unique_pages_written += unique_pages_written; + total_unique_dirty_pages += unique_dirty_pages; + total_write_operations += write_operations; + } + + println!("Concurrent test completed:"); + println!(" {} threads", NUM_THREADS); + println!(" {} updates per thread", UPDATES_PER_THREAD); + println!(" {} total write operations", total_write_operations); + println!( + " {} total unique pages written", + total_unique_pages_written + ); + println!( + " {} total unique dirty pages detected", + total_unique_dirty_pages + ); + + // Verify that we detected the expected number of dirty pages + assert!( + total_unique_dirty_pages > 0, + "No dirty pages detected across all threads" + ); + assert_eq!( + total_unique_pages_written, total_unique_dirty_pages, + "Mismatch between unique pages written and unique dirty pages detected" + ); + + // The total write operations should normally be much higher than unique pages (due to multiple writes to same pages) + assert!( + total_write_operations >= total_unique_pages_written, + "Total write operations ({}) should be >= unique pages written ({})", + total_write_operations, + total_unique_pages_written + ); + } + + #[test] + fn test_tracker_contains_address() { + let (memory_ptr, memory_size) = create_aligned_memory(PAGE_SIZE * 2); + let addr = memory_ptr as usize; + + let tracker = create_test_tracker(addr, memory_size).unwrap(); + + // Test address checking (internal method) + assert!(tracker.contains_address(addr)); + assert!(tracker.contains_address(addr + 100)); + assert!(tracker.contains_address(addr + memory_size - 1)); + assert!(!tracker.contains_address(addr - 1)); + assert!(!tracker.contains_address(addr + memory_size)); + + // Explicitly drop tracker before freeing memory + drop(tracker); + + unsafe { + free_aligned_memory(memory_ptr, memory_size); + } + } + + #[test] + fn test_write_protection_active() { + let (memory_ptr, memory_size) = create_aligned_memory(PAGE_SIZE); + let addr = memory_ptr as usize; + + let tracker = create_test_tracker(addr, memory_size).unwrap(); + + // Memory should be write-protected initially + // Writing should trigger SIGSEGV (which gets handled by our signal handler) + unsafe { + std::ptr::write_volatile((addr + 100) as *mut u8, 42); + } + + // If we get here without crashing, the signal handler worked + + // Explicitly drop tracker before freeing memory + drop(tracker); + + unsafe { + free_aligned_memory(memory_ptr, memory_size); + } + } + + #[test] + fn test_stress_multiple_writes() { + let (memory_ptr, memory_size) = create_aligned_memory(PAGE_SIZE * 5); + let addr = memory_ptr as usize; + + let tracker = create_test_tracker(addr, memory_size).unwrap(); + + // Write to many different pages and offsets + for page in 0..5 { + for offset in [0, 100, 500, 1000, PAGE_SIZE - 1] { + let write_addr = addr + (page * PAGE_SIZE) + offset; + if write_addr < addr + memory_size { + 
unsafe { + std::ptr::write_volatile(write_addr as *mut u8, (page + offset) as u8); + } + } + } + } + + let dirty_pages = tracker.get_dirty_pages().unwrap(); + println!("Stress test dirty pages: {:?}", dirty_pages); + + // Verify all page indices are valid + for &page_idx in &dirty_pages { + assert!(page_idx < 5, "Invalid page index: {}", page_idx); + } + + // tracker is already dropped by get_dirty_pages() call above + + unsafe { + free_aligned_memory(memory_ptr, memory_size); + } + } + + #[test] + fn test_pid_tracking_and_isolation() { + let (memory_ptr1, memory_size1) = create_aligned_memory(PAGE_SIZE * 4); + let (memory_ptr2, memory_size2) = create_aligned_memory(PAGE_SIZE * 4); + let addr1 = memory_ptr1 as usize; + let addr2 = memory_ptr2 as usize; + + // Create two trackers + let tracker1 = create_test_tracker(addr1, memory_size1).unwrap(); + let tracker2 = create_test_tracker(addr2, memory_size2).unwrap(); + + let current_pid = std::process::id(); + + // Verify that tracker data contains the correct PID + let trackers = get_trackers(); + let tracker1_data = trackers.get(&tracker1.id).unwrap(); + let tracker2_data = trackers.get(&tracker2.id).unwrap(); + + assert_eq!( + tracker1_data.val().pid, + current_pid, + "Tracker 1 should store the current process ID" + ); + assert_eq!( + tracker2_data.val().pid, + current_pid, + "Tracker 2 should store the current process ID" + ); + + // Explicitly drop trackers before freeing memory + drop(tracker1); + drop(tracker2); + + // Clean up + unsafe { + free_aligned_memory(memory_ptr1, memory_size1); + free_aligned_memory(memory_ptr2, memory_size2); + } + } + + #[test] + fn test_overlap_detection_with_same_virtual_addresses() { + // This test verifies that overlap detection is now scoped per process + // In a real multi-process scenario, different processes could have the same + // virtual addresses that map to different physical memory, so overlaps + // should only be checked within the same process. 
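The per-process rule described in the comment above can be reduced to a pure check over (pid, range) pairs. The sketch below is illustrative only: the helper names are hypothetical, and the real implementation consults the shared tracker registry rather than free functions.

```rust
// Illustrative sketch of the per-process overlap rule; names are hypothetical.
fn ranges_overlap(a_start: usize, a_len: usize, b_start: usize, b_len: usize) -> bool {
    // Half-open ranges [start, start + len) intersect iff each starts before the other ends.
    a_start < b_start + b_len && b_start < a_start + a_len
}

fn trackers_conflict(
    existing_pid: u32,
    existing_range: (usize, usize), // (start, len)
    new_pid: u32,
    new_range: (usize, usize),
) -> bool {
    // Identical virtual addresses in different processes refer to different physical
    // memory, so only ranges registered by the same process can conflict.
    existing_pid == new_pid
        && ranges_overlap(existing_range.0, existing_range.1, new_range.0, new_range.1)
}
```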
+ + let (memory_ptr, memory_size) = create_aligned_memory(PAGE_SIZE * 4); + let addr = memory_ptr as usize; + + // Create a tracker for this address range + let tracker1 = create_test_tracker(addr, memory_size).unwrap(); + + // Verify the tracker is storing the current PID + let current_pid = std::process::id(); + let trackers = get_trackers(); + let tracker_data = trackers.get(&tracker1.id).unwrap(); + assert_eq!(tracker_data.val().pid, current_pid); + + // Creating an overlapping tracker with the same PID should fail + let overlap_result = create_test_tracker(addr + PAGE_SIZE, PAGE_SIZE * 2); + assert!( + overlap_result.is_err(), + "Creating overlapping tracker in same process should fail" + ); + + // Clean up + drop(tracker1); + unsafe { + free_aligned_memory(memory_ptr, memory_size); + } + } +} diff --git a/src/hyperlight_host/src/mem/mgr.rs b/src/hyperlight_host/src/mem/mgr.rs index 90cb76573..e92f6bca3 100644 --- a/src/hyperlight_host/src/mem/mgr.rs +++ b/src/hyperlight_host/src/mem/mgr.rs @@ -24,6 +24,7 @@ use hyperlight_common::flatbuffer_wrappers::function_types::ReturnValue; use hyperlight_common::flatbuffer_wrappers::guest_error::GuestError; use hyperlight_common::flatbuffer_wrappers::guest_log_data::GuestLogData; use hyperlight_common::flatbuffer_wrappers::host_function_details::HostFunctionDetails; +use hyperlight_common::mem::PAGES_IN_BLOCK; use tracing::{Span, instrument}; use super::exe::ExeInfo; @@ -33,8 +34,9 @@ use super::memory_region::{DEFAULT_GUEST_BLOB_MEM_FLAGS, MemoryRegion, MemoryReg use super::ptr::{GuestPtr, RawPtr}; use super::ptr_offset::Offset; use super::shared_mem::{ExclusiveSharedMemory, GuestSharedMemory, HostSharedMemory, SharedMemory}; -use super::shared_mem_snapshot::SharedMemorySnapshot; -use crate::HyperlightError::NoMemorySnapshot; +use super::shared_memory_snapshot_manager::SharedMemorySnapshotManager; +use crate::mem::bitmap::{bitmap_union, new_page_bitmap}; +use crate::mem::dirty_page_tracking::DirtyPageTracker; use crate::sandbox::SandboxConfiguration; use crate::sandbox::uninitialized::GuestBlob; use crate::{Result, log_then_return, new_error}; @@ -75,9 +77,8 @@ pub(crate) struct SandboxMemoryManager { pub(crate) entrypoint_offset: Offset, /// How many memory regions were mapped after sandbox creation pub(crate) mapped_rgns: u64, - /// A vector of memory snapshots that can be used to save and restore the state of the memory - /// This is used by the Rust Sandbox implementation (rather than the mem_snapshot field above which only exists to support current C API) - snapshots: Arc>>, + /// Shared memory snapshots that can be used to save and restore the state of the memory + snapshot_manager: Arc>>, } impl SandboxMemoryManager @@ -98,7 +99,7 @@ where load_addr, entrypoint_offset, mapped_rgns: 0, - snapshots: Arc::new(Mutex::new(Vec::new())), + snapshot_manager: Arc::new(Mutex::new(None)), } } @@ -265,14 +266,40 @@ where } } - /// this function will create a memory snapshot and push it onto the stack of snapshots - /// It should be used when you want to save the state of the memory, for example, when evolving a sandbox to a new state - pub(crate) fn push_state(&mut self) -> Result<()> { - let snapshot = SharedMemorySnapshot::new(&mut self.shared_mem, self.mapped_rgns)?; - self.snapshots + /// this function will create an initial snapshot and then create the SnapshotManager + pub(crate) fn create_initial_snapshot( + &mut self, + vm_dirty_bitmap: &[u64], + host_dirty_page_idx: &[usize], + layout: &SandboxMemoryLayout, + ) -> Result<()> { + let mut 
existing_snapshot_manager = self + .snapshot_manager .try_lock() - .map_err(|e| new_error!("Error locking at {}:{}: {}", file!(), line!(), e))? - .push(snapshot); + .map_err(|e| new_error!("Error locking at {}:{}: {}", file!(), line!(), e))?; + + if existing_snapshot_manager.is_some() { + log_then_return!("Snapshot manager already initialized, not creating a new one"); + } + + // covert vec of page indices to bitmap + let mut res = new_page_bitmap(self.shared_mem.raw_mem_size(), false)?; + for page_idx in host_dirty_page_idx { + let block_idx = page_idx / PAGES_IN_BLOCK; + let bit_idx = page_idx % PAGES_IN_BLOCK; + res[block_idx] |= 1 << bit_idx; + } + + // merge the host dirty page map into the dirty bitmap + let merged = bitmap_union(&res, vm_dirty_bitmap); + + let snapshot_manager = SharedMemorySnapshotManager::new( + &mut self.shared_mem, + &merged, + layout, + self.mapped_rgns, + )?; + existing_snapshot_manager.replace(snapshot_manager); Ok(()) } @@ -280,39 +307,59 @@ where /// off the stack /// It should be used when you want to restore the state of the memory to a previous state but still want to /// retain that state, for example after calling a function in the guest - /// - /// Returns the number of memory regions mapped into the sandbox - /// that need to be unmapped in order for the restore to be - /// completed. - pub(crate) fn restore_state_from_last_snapshot(&mut self) -> Result { - let mut snapshots = self - .snapshots + pub(crate) fn restore_state_from_last_snapshot(&mut self, dirty_bitmap: &[u64]) -> Result { + let mut snapshot_manager = self + .snapshot_manager .try_lock() .map_err(|e| new_error!("Error locking at {}:{}: {}", file!(), line!(), e))?; - let last = snapshots.last_mut(); - if last.is_none() { - log_then_return!(NoMemorySnapshot); + + match snapshot_manager.as_mut() { + None => { + log_then_return!("Snapshot manager not initialized"); + } + Some(snapshot_manager) => { + snapshot_manager.restore_from_snapshot(&mut self.shared_mem, dirty_bitmap) + } } - #[allow(clippy::unwrap_used)] // We know that last is not None because we checked it above - let snapshot = last.unwrap(); - let old_rgns = self.mapped_rgns; - self.mapped_rgns = snapshot.restore_from_snapshot(&mut self.shared_mem)?; - Ok(old_rgns - self.mapped_rgns) } /// this function pops the last snapshot off the stack and restores the memory to the previous state /// It should be used when you want to restore the state of the memory to a previous state and do not need to retain that state /// for example when devolving a sandbox to a previous state. - pub(crate) fn pop_and_restore_state_from_snapshot(&mut self) -> Result { - let last = self - .snapshots + pub(crate) fn pop_and_restore_state_from_snapshot( + &mut self, + dirty_bitmap: &[u64], + ) -> Result { + let mut snapshot_manager = self + .snapshot_manager .try_lock() - .map_err(|e| new_error!("Error locking at {}:{}: {}", file!(), line!(), e))? 
- .pop(); - if last.is_none() { - log_then_return!(NoMemorySnapshot); + .map_err(|e| new_error!("Error locking at {}:{}: {}", file!(), line!(), e))?; + + match snapshot_manager.as_mut() { + None => { + log_then_return!("Snapshot manager not initialized"); + } + Some(snapshot_manager) => snapshot_manager + .pop_and_restore_state_from_snapshot(&mut self.shared_mem, dirty_bitmap), + } + } + + pub(crate) fn push_state(&mut self, dirty_bitmap: &[u64]) -> Result<()> { + let mut snapshot_manager = self + .snapshot_manager + .try_lock() + .map_err(|e| new_error!("Error locking at {}:{}: {}", file!(), line!(), e))?; + + match snapshot_manager.as_mut() { + None => { + log_then_return!("Snapshot manager not initialized"); + } + Some(snapshot_manager) => snapshot_manager.create_new_snapshot( + &mut self.shared_mem, + dirty_bitmap, + self.mapped_rgns, + ), } - self.restore_state_from_last_snapshot() } /// Sets `addr` to the correct offset in the memory referenced by @@ -347,7 +394,7 @@ impl SandboxMemoryManager { cfg: SandboxConfiguration, exe_info: &mut ExeInfo, guest_blob: Option<&GuestBlob>, - ) -> Result { + ) -> Result<(Self, DirtyPageTracker)> { let guest_blob_size = guest_blob.map(|b| b.data.len()).unwrap_or(0); let guest_blob_mem_flags = guest_blob.map(|b| b.permissions); @@ -360,6 +407,7 @@ impl SandboxMemoryManager { guest_blob_mem_flags, )?; let mut shared_mem = ExclusiveSharedMemory::new(layout.get_memory_size()?)?; + let tracker = shared_mem.start_tracking_dirty_pages()?; let load_addr: RawPtr = RawPtr::try_from(layout.get_guest_code_address())?; @@ -378,7 +426,10 @@ impl SandboxMemoryManager { &mut shared_mem.as_mut_slice()[layout.get_guest_code_offset()..], )?; - Ok(Self::new(layout, shared_mem, load_addr, entrypoint_offset)) + Ok(( + Self::new(layout, shared_mem, load_addr, entrypoint_offset), + tracker, + )) } /// Writes host function details to memory @@ -414,6 +465,7 @@ impl SandboxMemoryManager { host_function_call_buffer.as_slice(), self.layout.host_function_definitions_buffer_offset, )?; + Ok(()) } @@ -440,7 +492,7 @@ impl SandboxMemoryManager { load_addr: self.load_addr.clone(), entrypoint_offset: self.entrypoint_offset, mapped_rgns: 0, - snapshots: Arc::new(Mutex::new(Vec::new())), + snapshot_manager: Arc::new(Mutex::new(None)), }, SandboxMemoryManager { shared_mem: gshm, @@ -448,7 +500,7 @@ impl SandboxMemoryManager { load_addr: self.load_addr.clone(), entrypoint_offset: self.entrypoint_offset, mapped_rgns: 0, - snapshots: Arc::new(Mutex::new(Vec::new())), + snapshot_manager: Arc::new(Mutex::new(None)), }, ) } diff --git a/src/hyperlight_host/src/mem/mod.rs b/src/hyperlight_host/src/mem/mod.rs index 1bcc03eae..a7c36e1ea 100644 --- a/src/hyperlight_host/src/mem/mod.rs +++ b/src/hyperlight_host/src/mem/mod.rs @@ -14,17 +14,25 @@ See the License for the specific language governing permissions and limitations under the License. */ +/// Various helper functions for working with bitmaps +pub(crate) mod bitmap; +/// a module for tracking dirty pages in the host. +pub(crate) mod dirty_page_tracking; /// A simple ELF loader pub(crate) mod elf; /// A generic wrapper for executable files (PE, ELF, etc) pub(crate) mod exe; /// Functionality to establish a sandbox's memory layout. pub mod layout; +#[cfg(target_os = "linux")] +mod linux_dirty_page_tracker; /// memory regions to be mapped inside a vm pub mod memory_region; /// Functionality that wraps a `SandboxMemoryLayout` and a /// `SandboxMemoryConfig` to mutate a sandbox's memory as necessary. 
pub mod mgr; +/// A compact snapshot representation for memory pages +pub(crate) mod page_snapshot; /// Structures to represent pointers into guest and host memory pub mod ptr; /// Structures to represent memory address spaces into which pointers @@ -35,9 +43,10 @@ pub mod ptr_offset; /// A wrapper around unsafe functionality to create and initialize /// a memory region for a guest running in a sandbox. pub mod shared_mem; -/// A wrapper around a `SharedMemory` and a snapshot in time -/// of the memory therein -pub mod shared_mem_snapshot; /// Utilities for writing shared memory tests #[cfg(test)] pub(crate) mod shared_mem_tests; +/// A wrapper around a `SharedMemory` to manage snapshots of the memory +pub mod shared_memory_snapshot_manager; +#[cfg(target_os = "windows")] +mod windows_dirty_page_tracker; diff --git a/src/hyperlight_host/src/mem/page_snapshot.rs b/src/hyperlight_host/src/mem/page_snapshot.rs new file mode 100644 index 000000000..3c79013f6 --- /dev/null +++ b/src/hyperlight_host/src/mem/page_snapshot.rs @@ -0,0 +1,97 @@ +/* +Copyright 2025 The Hyperlight Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +use std::collections::HashMap; + +use hyperlight_common::mem::PAGE_SIZE_USIZE; + +/// A compact snapshot representation that stores pages in a contiguous buffer +/// with an index for efficient lookup. +/// +/// This struct is designed to efficiently store and retrieve memory snapshots +/// by using a contiguous buffer for all page data combined with a HashMap index +/// for page lookups. This approach reduces memory overhead +/// compared to storing pages individually. +/// +/// # Clone Derivation +/// +/// This struct derives `Clone` because it's stored in `Vec` within +/// `SharedMemorySnapshotManager`, which itself derives `Clone`. +#[derive(Clone)] +pub(super) struct PageSnapshot { + /// Maps page numbers to their offset within the buffer (in page units) + page_index: HashMap, // page_number -> buffer_offset_in_pages + /// Contiguous buffer containing all the page data + buffer: Vec, + /// How many non-main-RAM regions were mapped when this snapshot was taken? 
+ mapped_rgns: u64, +} + +impl PageSnapshot { + /// Create a new empty snapshot + pub(super) fn new() -> Self { + Self { + page_index: HashMap::new(), + buffer: Vec::new(), + mapped_rgns: 0, + } + } + + /// Create a snapshot from a list of page numbers with pre-allocated buffer + pub(super) fn with_pages_and_buffer( + page_numbers: Vec, + buffer: Vec, + mapped_rgns: u64, + ) -> Self { + let page_count = page_numbers.len(); + let mut page_index = HashMap::with_capacity(page_count); + + // Map each page number to its offset in the buffer + for (buffer_offset, page_num) in page_numbers.into_iter().enumerate() { + page_index.insert(page_num, buffer_offset); + } + + Self { + page_index, + buffer, + mapped_rgns, + } + } + + /// Get page data by page number, returns None if page is not in snapshot + pub(super) fn get_page(&self, page_num: usize) -> Option<&[u8]> { + self.page_index.get(&page_num).map(|&buffer_offset| { + let start = buffer_offset * PAGE_SIZE_USIZE; + let end = start + PAGE_SIZE_USIZE; + &self.buffer[start..end] + }) + } + + /// Get an iterator over all page numbers in this snapshot + pub(super) fn page_numbers(&self) -> impl Iterator + '_ { + self.page_index.keys().copied() + } + + /// Get the maximum page number in this snapshot, or None if empty + pub(super) fn max_page(&self) -> Option { + self.page_index.keys().max().copied() + } + + /// Get the number of mapped regions when this snapshot was taken + pub(super) fn mapped_rgns(&self) -> u64 { + self.mapped_rgns + } +} diff --git a/src/hyperlight_host/src/mem/shared_mem.rs b/src/hyperlight_host/src/mem/shared_mem.rs index 50c809f44..c7b5a1e0a 100644 --- a/src/hyperlight_host/src/mem/shared_mem.rs +++ b/src/hyperlight_host/src/mem/shared_mem.rs @@ -39,6 +39,7 @@ use windows::core::PCSTR; use crate::HyperlightError::MemoryAllocationFailed; #[cfg(target_os = "windows")] use crate::HyperlightError::{MemoryRequestTooBig, WindowsAPIError}; +use crate::mem::dirty_page_tracking::{DirtyPageTracker, DirtyPageTracking}; use crate::{Result, log_then_return, new_error}; /// Makes sure that the given `offset` and `size` are within the bounds of the memory with size `mem_size`. @@ -387,6 +388,19 @@ impl ExclusiveSharedMemory { }) } + /// Starts tracking dirty pages in the shared memory region. + pub(super) fn start_tracking_dirty_pages(&self) -> Result { + DirtyPageTracker::new(self) + } + + /// Stop tracking dirty pages in the shared memory region. + pub(crate) fn stop_tracking_dirty_pages( + &self, + tracker: DirtyPageTracker, + ) -> Result> { + tracker.get_dirty_pages() + } + /// Create a new region of shared memory with the given minimum /// size in bytes. The region will be surrounded by guard pages. /// @@ -613,6 +627,15 @@ impl ExclusiveSharedMemory { Ok(()) } + /// Copies bytes to slice from self starting at offset + #[instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace")] + pub fn copy_to_slice(&self, slice: &mut [u8], offset: usize) -> Result<()> { + let data = self.as_slice(); + bounds_check!(offset, slice.len(), data.len()); + slice.copy_from_slice(&data[offset..offset + slice.len()]); + Ok(()) + } + /// Return the address of memory at an offset to this `SharedMemory` checking /// that the memory is within the bounds of the `SharedMemory`. 
#[instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace")] @@ -621,6 +644,15 @@ impl ExclusiveSharedMemory { Ok(self.base_addr() + offset) } + /// Fill the memory in the range `[offset, offset + len)` with `value` + #[instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace")] + pub fn zero_fill(&mut self, offset: usize, len: usize) -> Result<()> { + bounds_check!(offset, len, self.mem_size()); + let data = self.as_mut_slice(); + data[offset..offset + len].fill(0); + Ok(()) + } + generate_reader!(read_u8, u8); generate_reader!(read_i8, i8); generate_reader!(read_u16, u16); @@ -678,6 +710,9 @@ pub trait SharedMemory { /// Return a readonly reference to the host mapping backing this SharedMemory fn region(&self) -> &HostMapping; + /// Return an Arc clone of the host mapping backing this SharedMemory + fn region_arc(&self) -> Arc; + /// Return the base address of the host mapping of this /// region. Following the general Rust philosophy, this does not /// need to be marked as `unsafe` because doing anything with this @@ -728,6 +763,11 @@ impl SharedMemory for ExclusiveSharedMemory { fn region(&self) -> &HostMapping { &self.region } + + fn region_arc(&self) -> Arc { + Arc::clone(&self.region) + } + fn with_exclusivity T>( &mut self, f: F, @@ -740,6 +780,11 @@ impl SharedMemory for GuestSharedMemory { fn region(&self) -> &HostMapping { &self.region } + + fn region_arc(&self) -> Arc { + Arc::clone(&self.region) + } + fn with_exclusivity T>( &mut self, f: F, @@ -982,6 +1027,11 @@ impl SharedMemory for HostSharedMemory { fn region(&self) -> &HostMapping { &self.region } + + fn region_arc(&self) -> Arc { + Arc::clone(&self.region) + } + fn with_exclusivity T>( &mut self, f: F, diff --git a/src/hyperlight_host/src/mem/shared_memory_snapshot_manager.rs b/src/hyperlight_host/src/mem/shared_memory_snapshot_manager.rs new file mode 100644 index 000000000..bf2f20e6c --- /dev/null +++ b/src/hyperlight_host/src/mem/shared_memory_snapshot_manager.rs @@ -0,0 +1,1314 @@ +/* +Copyright 2025 The Hyperlight Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +use hyperlight_common::mem::{PAGE_SIZE_USIZE, PAGES_IN_BLOCK}; +use tracing::{Span, instrument}; + +use super::page_snapshot::PageSnapshot; +use super::shared_mem::SharedMemory; +use crate::Result; +use crate::mem::bitmap::{bit_index_iterator, bitmap_union}; +use crate::mem::layout::SandboxMemoryLayout; + +/// A wrapper around a `SharedMemory` reference and a snapshot +/// of the memory therein +pub(super) struct SharedMemorySnapshotManager { + /// A vector of snapshots, each snapshot contains only the dirty pages in a compact format + /// The first snapshot is the initial state of the memory, subsequent snapshots after initialization + /// snapshots are deltas from the previous state (i.e. 
only the dirty pages are stored) + /// The initial snapshot is a delta from zeroing the memory on allocation + snapshots: Vec, + /// The offsets of the input and output data buffers in the memory layout are stored + /// this allows us to reset the input and output buffers to their initial state (i.e. zeroed) + /// each time we restore from a snapshot + input_data_size: usize, + output_data_size: usize, + output_data_buffer_offset: usize, + input_data_buffer_offset: usize, +} + +impl SharedMemorySnapshotManager { + /// Take a snapshot of the memory in `shared_mem`, then create a new + /// instance of `Self` with the snapshot stored therein. + #[instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace")] + pub(super) fn new( + shared_mem: &mut S, + dirty_page_map: &[u64], + layout: &SandboxMemoryLayout, + mapped_rgns: u64, + ) -> Result { + // Build a snapshot of memory from the dirty_page_map + + let diff = + Self::build_snapshot_from_dirty_page_map(shared_mem, dirty_page_map, mapped_rgns)?; + + // Get the input output buffer details from the layout so that they can be reset to their initial state + let input_data_size_offset = layout.get_input_data_size_offset(); + let output_data_size_offset = layout.get_output_data_size_offset(); + let output_data_buffer_offset = layout.get_output_data_pointer_offset(); + let input_data_buffer_offset = layout.get_input_data_pointer_offset(); + + // Read the input and output data sizes and pointers from memory + let ( + input_data_size, + output_data_size, + output_data_buffer_offset, + input_data_buffer_offset, + ) = shared_mem.with_exclusivity(|e| -> Result<(usize, usize, usize, usize)> { + Ok(( + e.read_usize(input_data_size_offset)?, + e.read_usize(output_data_size_offset)?, + e.read_usize(output_data_buffer_offset)?, + e.read_usize(input_data_buffer_offset)?, + )) + })??; + + Ok(Self { + snapshots: vec![diff], + input_data_size, + output_data_size, + output_data_buffer_offset, + input_data_buffer_offset, + }) + } + + fn build_snapshot_from_dirty_page_map( + shared_mem: &mut S, + dirty_page_map: &[u64], + mapped_rgns: u64, + ) -> Result { + // If there is no dirty page map, return an empty snapshot + if dirty_page_map.is_empty() { + return Ok(PageSnapshot::new()); + } + + // Should not happen, but just in case + if dirty_page_map.is_empty() { + return Ok(PageSnapshot::new()); + } + + let mut dirty_pages: Vec = bit_index_iterator(dirty_page_map).collect(); + + // Pre-allocate buffer for all pages + let page_count = dirty_pages.len(); + let total_size = page_count * PAGE_SIZE_USIZE; + let mut buffer = vec![0u8; total_size]; + + // if the total size is equal to the shared memory size, we can optimize the copy + if total_size == shared_mem.mem_size() { + // Copy the entire memory region in one go + shared_mem.with_exclusivity(|e| e.copy_to_slice(&mut buffer, 0))??; + } else { + // Sort pages for deterministic ordering and to enable consecutive page optimization + dirty_pages.sort_unstable(); + + let mut buffer_offset = 0; + let mut i = 0; + + while i < dirty_pages.len() { + let start_page = dirty_pages[i]; + let mut consecutive_count = 1; + + // Find consecutive pages + while i + consecutive_count < dirty_pages.len() + && dirty_pages[i + consecutive_count] == start_page + consecutive_count + { + consecutive_count += 1; + } + + // Calculate memory positions + let memory_offset = start_page * PAGE_SIZE_USIZE; + let copy_size = consecutive_count * PAGE_SIZE_USIZE; + let buffer_end = buffer_offset + copy_size; + + // Single copy operation for 
consecutive pages directly into final buffer + shared_mem.with_exclusivity(|e| { + e.copy_to_slice(&mut buffer[buffer_offset..buffer_end], memory_offset) + })??; + // copy_operations += 1; + + buffer_offset += copy_size; + i += consecutive_count; + } + } + + // Create the snapshot with the pre-allocated buffer + let snapshot = PageSnapshot::with_pages_and_buffer(dirty_pages, buffer, mapped_rgns); + + Ok(snapshot) + } + + pub(super) fn create_new_snapshot( + &mut self, + shared_mem: &mut S, + dirty_page_map: &[u64], + mapped_rgns: u64, + ) -> Result<()> { + let snapshot = + Self::build_snapshot_from_dirty_page_map(shared_mem, dirty_page_map, mapped_rgns)?; + self.snapshots.push(snapshot); + Ok(()) + } + + /// Copy the memory from the internally-stored memory snapshot + /// into the internally-stored `SharedMemory` + #[instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace")] + pub(super) fn restore_from_snapshot( + &mut self, + shared_mem: &mut S, + dirty_bitmap: &[u64], + ) -> Result { + // check the each index in the dirty bitmap and restore only the corresponding pages from the snapshots vector + // starting at the last snapshot look for the page in each snapshot if it exists and restore it + // if it does not exist set the page to zero + if self.snapshots.is_empty() { + return Err(crate::HyperlightError::NoMemorySnapshot); + } + + // Collect dirty pages and sort them for consecutive page optimization + let mut dirty_pages: Vec = bit_index_iterator(dirty_bitmap).collect(); + dirty_pages.sort_unstable(); + + let mut i = 0; + while i < dirty_pages.len() { + let start_page = dirty_pages[i]; + let mut consecutive_count = 1; + + // Find consecutive pages + while i + consecutive_count < dirty_pages.len() + && dirty_pages[i + consecutive_count] == start_page + consecutive_count + { + consecutive_count += 1; + } + + // Build buffer for consecutive pages + let mut buffer = vec![0u8; consecutive_count * PAGE_SIZE_USIZE]; + let mut buffer_offset = 0; + + for page_idx in 0..consecutive_count { + let page = start_page + page_idx; + + // Check for the page in every snapshot starting from the last one + for snapshot in self.snapshots.iter().rev() { + if let Some(data) = snapshot.get_page(page) { + buffer[buffer_offset..buffer_offset + PAGE_SIZE_USIZE] + .copy_from_slice(data); + break; + } + } + + buffer_offset += PAGE_SIZE_USIZE; + + // If the page was not found in any snapshot, it will be now be zero in the buffer as we skip over it above and didnt write any data + // This is the correct state as the page was not dirty in any snapshot which means it should be zeroed (the initial state) + } + + // Single copy operation for all consecutive pages + let memory_offset = start_page * PAGE_SIZE_USIZE; + shared_mem.with_exclusivity(|e| e.copy_from_slice(&buffer, memory_offset))??; + + i += consecutive_count; + } + // Reset input/output buffers these need to set to their initial state each time a snapshot is restored to clear any previous io/data that may be in the buffers + shared_mem.with_exclusivity(|e| { + e.zero_fill(self.input_data_buffer_offset, self.input_data_size)?; + e.zero_fill(self.output_data_buffer_offset, self.output_data_size)?; + e.write_u64( + self.input_data_buffer_offset, + SandboxMemoryLayout::STACK_POINTER_SIZE_BYTES, + )?; + e.write_u64( + self.output_data_buffer_offset, + SandboxMemoryLayout::STACK_POINTER_SIZE_BYTES, + ) + })??; + + #[allow(clippy::unwrap_used)] + Ok(self.snapshots.last().unwrap().mapped_rgns()) + } + + pub(super) fn 
pop_and_restore_state_from_snapshot( + &mut self, + shared_mem: &mut S, + dirty_bitmap: &[u64], + ) -> Result { + // Check that there is a snapshot to restore from + if self.snapshots.is_empty() { + return Err(crate::HyperlightError::NoMemorySnapshot); + } + // Get the last snapshot index + let last_snapshot_index = self.snapshots.len() - 1; + let last_snapshot_bitmap = self.get_bitmap_from_snapshot(last_snapshot_index); + // merge the last snapshot bitmap with the dirty bitmap + let merged_bitmap = bitmap_union(&last_snapshot_bitmap, dirty_bitmap); + + // drop the last snapshot then restore the state from the merged bitmap + if self.snapshots.pop().is_none() { + return Err(crate::HyperlightError::NoMemorySnapshot); + } + + // restore the state from the last snapshot + self.restore_from_snapshot(shared_mem, &merged_bitmap) + } + + fn get_bitmap_from_snapshot(&self, snapshot_index: usize) -> Vec { + // Get the snapshot at the given index + if snapshot_index < self.snapshots.len() { + let snapshot = &self.snapshots[snapshot_index]; + // Create a bitmap from the snapshot + let max_page = snapshot.max_page().unwrap_or_default(); + let num_blocks = max_page.div_ceil(PAGES_IN_BLOCK); + let mut bitmap = vec![0u64; num_blocks]; + for page in snapshot.page_numbers() { + let block = page / PAGES_IN_BLOCK; + let offset = page % PAGES_IN_BLOCK; + if block < bitmap.len() { + bitmap[block] |= 1 << offset; + } + } + bitmap + } else { + vec![] + } + } +} + +#[cfg(test)] +mod tests { + use hyperlight_common::mem::PAGE_SIZE_USIZE; + + use super::super::layout::SandboxMemoryLayout; + use crate::mem::bitmap::new_page_bitmap; + use crate::mem::shared_mem::{ExclusiveSharedMemory, SharedMemory}; + use crate::sandbox::SandboxConfiguration; + + fn create_test_layout() -> SandboxMemoryLayout { + let cfg = SandboxConfiguration::default(); + // Create a layout with large init_data area for testing (64KB for plenty of test pages) + let init_data_size = 64 * 1024; // 64KB = 16 pages of 4KB each + SandboxMemoryLayout::new(cfg, 4096, 16384, 16384, init_data_size, None).unwrap() + } + + fn create_test_shared_memory_with_layout( + layout: &SandboxMemoryLayout, + ) -> ExclusiveSharedMemory { + let memory_size = layout.get_memory_size().unwrap(); + let mut shared_mem = ExclusiveSharedMemory::new(memory_size).unwrap(); + + // Initialize the memory with the full layout to ensure it's properly set up + layout + .write( + &mut shared_mem, + SandboxMemoryLayout::BASE_ADDRESS, + memory_size, + ) + .unwrap(); + + shared_mem + } + + /// Get safe memory area for testing - uses init_data area which is safe to modify + fn get_safe_test_area( + layout: &SandboxMemoryLayout, + shared_mem: &mut ExclusiveSharedMemory, + ) -> (usize, usize) { + // The init_data area is positioned after the guest stack in the memory layout + // We can safely use this area for testing as it's designed for initialization data + // Read the actual init_data buffer offset and size from memory + let init_data_size_offset = layout.get_init_data_size_offset(); + let init_data_pointer_offset = layout.get_init_data_pointer_offset(); + + let (init_data_size, init_data_buffer_offset) = shared_mem + .with_exclusivity(|e| -> crate::Result<(usize, usize)> { + Ok(( + e.read_usize(init_data_size_offset)?, + e.read_usize(init_data_pointer_offset)?, + )) + }) + .unwrap() + .unwrap(); + + (init_data_buffer_offset, init_data_size) + } + + #[test] + fn test_single_snapshot_restore() { + let layout = create_test_layout(); + let mut shared_mem = 
create_test_shared_memory_with_layout(&layout); + + // Get safe init_data area for testing + let (init_data_offset, init_data_size) = get_safe_test_area(&layout, &mut shared_mem); + + // Use a safe page well within the init_data area + let safe_offset = init_data_offset + PAGE_SIZE_USIZE; // Skip first page for extra safety + + // Ensure we have enough space for testing + assert!( + init_data_size >= 2 * PAGE_SIZE_USIZE, + "Init data area too small for testing: {} bytes", + init_data_size + ); + assert!( + safe_offset + PAGE_SIZE_USIZE <= init_data_offset + init_data_size, + "Safe offset exceeds init_data bounds" + ); + + // Start tracking dirty pages + let tracker = shared_mem.start_tracking_dirty_pages().unwrap(); + + // Initial data - only initialize safe page, leave other pages as zero + let initial_data = vec![0xAA; PAGE_SIZE_USIZE]; + shared_mem + .copy_from_slice(&initial_data, safe_offset) + .unwrap(); + + // Stop tracking and get dirty pages bitmap + let dirty_pages_vec = shared_mem.stop_tracking_dirty_pages(tracker).unwrap(); + + // Convert to bitmap format + let mut dirty_pages = new_page_bitmap(shared_mem.mem_size(), false).unwrap(); + for page_idx in dirty_pages_vec { + let block = page_idx / 64; + let bit = page_idx % 64; + if block < dirty_pages.len() { + dirty_pages[block] |= 1 << bit; + } + } + + // Create snapshot + let mut snapshot_manager = + super::SharedMemorySnapshotManager::new(&mut shared_mem, &dirty_pages, &layout, 0) + .unwrap(); + + // Modify memory + let modified_data = vec![0xBB; PAGE_SIZE_USIZE]; + shared_mem + .copy_from_slice(&modified_data, safe_offset) + .unwrap(); + + // Verify modification + let mut current_data = vec![0u8; PAGE_SIZE_USIZE]; + shared_mem + .copy_to_slice(&mut current_data, safe_offset) + .unwrap(); + assert_eq!(current_data, modified_data); + + // Restore from snapshot + snapshot_manager + .restore_from_snapshot(&mut shared_mem, &dirty_pages) + .unwrap(); + + // Verify restoration + let mut restored_data = vec![0u8; PAGE_SIZE_USIZE]; + shared_mem + .copy_to_slice(&mut restored_data, safe_offset) + .unwrap(); + assert_eq!(restored_data, initial_data); + } + + #[test] + fn test_multiple_snapshots_and_restores() { + let layout = create_test_layout(); + let mut shared_mem = create_test_shared_memory_with_layout(&layout); + + // Get safe init_data area for testing + let (init_data_offset, init_data_size) = get_safe_test_area(&layout, &mut shared_mem); + + // Use a safe page well within the init_data area + let safe_offset = init_data_offset + PAGE_SIZE_USIZE; // Skip first page for extra safety + + // Ensure we have enough space for testing + assert!( + init_data_size >= 2 * PAGE_SIZE_USIZE, + "Init data area too small for testing" + ); + assert!( + safe_offset + PAGE_SIZE_USIZE <= init_data_offset + init_data_size, + "Safe offset exceeds init_data bounds" + ); + + // Start tracking dirty pages + let tracker = shared_mem.start_tracking_dirty_pages().unwrap(); + + // State 1: Initial state + let state1_data = vec![0x11; PAGE_SIZE_USIZE]; + shared_mem + .copy_from_slice(&state1_data, safe_offset) + .unwrap(); + + // Stop tracking and get dirty pages bitmap + let dirty_pages_vec = shared_mem.stop_tracking_dirty_pages(tracker).unwrap(); + + // Convert to bitmap format + let mut dirty_pages = new_page_bitmap(shared_mem.mem_size(), false).unwrap(); + for page_idx in dirty_pages_vec { + let block = page_idx / 64; + let bit = page_idx % 64; + if block < dirty_pages.len() { + dirty_pages[block] |= 1 << bit; + } + } + + // Create initial snapshot 
(State 1) + let mut snapshot_manager = + super::SharedMemorySnapshotManager::new(&mut shared_mem, &dirty_pages, &layout, 0) + .unwrap(); + + // State 2: Modify and create second snapshot + let tracker2 = shared_mem.start_tracking_dirty_pages().unwrap(); + let state2_data = vec![0x22; PAGE_SIZE_USIZE]; + shared_mem + .copy_from_slice(&state2_data, safe_offset) + .unwrap(); + let dirty_pages_vec2 = shared_mem.stop_tracking_dirty_pages(tracker2).unwrap(); + + let mut dirty_pages2 = new_page_bitmap(shared_mem.mem_size(), false).unwrap(); + for page_idx in dirty_pages_vec2 { + let block = page_idx / 64; + let bit = page_idx % 64; + if block < dirty_pages2.len() { + dirty_pages2[block] |= 1 << bit; + } + } + + snapshot_manager + .create_new_snapshot(&mut shared_mem, &dirty_pages2, 0) + .unwrap(); + + // State 3: Modify again + let state3_data = vec![0x33; PAGE_SIZE_USIZE]; + shared_mem + .copy_from_slice(&state3_data, safe_offset) + .unwrap(); + + // Verify we're in state 3 + let mut current_data = vec![0u8; PAGE_SIZE_USIZE]; + shared_mem + .copy_to_slice(&mut current_data, safe_offset) + .unwrap(); + assert_eq!(current_data, state3_data); + + // Restore to state 2 (most recent snapshot) + snapshot_manager + .restore_from_snapshot(&mut shared_mem, &dirty_pages2) + .unwrap(); + let mut restored_data_state2 = vec![0u8; PAGE_SIZE_USIZE]; + shared_mem + .copy_to_slice(&mut restored_data_state2, safe_offset) + .unwrap(); + assert_eq!(restored_data_state2, state2_data); + + // Pop state 2 and restore to state 1 + snapshot_manager + .pop_and_restore_state_from_snapshot(&mut shared_mem, &dirty_pages) + .unwrap(); + let mut restored_data_state1 = vec![0u8; PAGE_SIZE_USIZE]; + shared_mem + .copy_to_slice(&mut restored_data_state1, safe_offset) + .unwrap(); + assert_eq!(restored_data_state1, state1_data); + } + + #[test] + fn test_multiple_pages_snapshot_restore() { + let layout = create_test_layout(); + let mut shared_mem = create_test_shared_memory_with_layout(&layout); + + // Get safe init_data area for testing + let (init_data_offset, init_data_size) = get_safe_test_area(&layout, &mut shared_mem); + + // Ensure we have enough space for 4 test pages + assert!( + init_data_size >= 6 * PAGE_SIZE_USIZE, + "Init data area too small for testing multiple pages" + ); + + // Use page offsets within the init_data area, skipping first page for safety + let base_page = (init_data_offset + PAGE_SIZE_USIZE) / PAGE_SIZE_USIZE; + let page_offsets = [base_page, base_page + 1, base_page + 2, base_page + 3]; + + let page_data = [ + vec![0xAA; PAGE_SIZE_USIZE], + vec![0xBB; PAGE_SIZE_USIZE], + vec![0xCC; PAGE_SIZE_USIZE], + vec![0xDD; PAGE_SIZE_USIZE], + ]; + + // Start tracking dirty pages + let tracker = shared_mem.start_tracking_dirty_pages().unwrap(); + + // Initialize data in init_data pages + for (i, &page_offset) in page_offsets.iter().enumerate() { + let offset = page_offset * PAGE_SIZE_USIZE; + assert!( + offset + PAGE_SIZE_USIZE <= shared_mem.mem_size(), + "Page offset {} exceeds memory bounds", + page_offset + ); + assert!( + offset >= init_data_offset + && offset + PAGE_SIZE_USIZE <= init_data_offset + init_data_size, + "Page offset {} is outside init_data bounds", + page_offset + ); + shared_mem.copy_from_slice(&page_data[i], offset).unwrap(); + } + + // Stop tracking and get dirty pages bitmap + let dirty_pages_vec = shared_mem.stop_tracking_dirty_pages(tracker).unwrap(); + + // Convert to bitmap format + let mut dirty_pages = new_page_bitmap(shared_mem.mem_size(), false).unwrap(); + for page_idx in 
dirty_pages_vec { + let block = page_idx / 64; + let bit = page_idx % 64; + if block < dirty_pages.len() { + dirty_pages[block] |= 1 << bit; + } + } + + // Create snapshot + let mut snapshot_manager = + super::SharedMemorySnapshotManager::new(&mut shared_mem, &dirty_pages, &layout, 0) + .unwrap(); + + // Modify first and third pages + let modified_data = [vec![0x11; PAGE_SIZE_USIZE], vec![0x22; PAGE_SIZE_USIZE]]; + shared_mem + .copy_from_slice(&modified_data[0], page_offsets[0] * PAGE_SIZE_USIZE) + .unwrap(); + shared_mem + .copy_from_slice(&modified_data[1], page_offsets[2] * PAGE_SIZE_USIZE) + .unwrap(); + + // Restore from snapshot + snapshot_manager + .restore_from_snapshot(&mut shared_mem, &dirty_pages) + .unwrap(); + + // Verify restoration + for (i, &page_offset) in page_offsets.iter().enumerate() { + let mut restored_data = vec![0u8; PAGE_SIZE_USIZE]; + shared_mem + .copy_to_slice(&mut restored_data, page_offset * PAGE_SIZE_USIZE) + .unwrap(); + assert_eq!( + restored_data, page_data[i], + "Page {} should be restored to original data", + i + ); + } + } + + #[test] + fn test_sequential_modifications_with_snapshots() { + let layout = create_test_layout(); + let mut shared_mem = create_test_shared_memory_with_layout(&layout); + + // Get safe init_data area for testing + let (init_data_offset, init_data_size) = get_safe_test_area(&layout, &mut shared_mem); + + // Use safe page offsets within init_data area + let safe_offset1 = init_data_offset + PAGE_SIZE_USIZE; // Skip first page for safety + let safe_offset2 = init_data_offset + 2 * PAGE_SIZE_USIZE; + + // Ensure we have enough space for testing + assert!( + init_data_size >= 3 * PAGE_SIZE_USIZE, + "Init data area too small for testing" + ); + assert!( + safe_offset2 + PAGE_SIZE_USIZE <= init_data_offset + init_data_size, + "Safe offsets exceed init_data bounds" + ); + + // Start tracking dirty pages + let tracker1 = shared_mem.start_tracking_dirty_pages().unwrap(); + + // Cycle 1: Set initial data + let cycle1_page0 = (0..PAGE_SIZE_USIZE) + .map(|i| (i % 256) as u8) + .collect::>(); + let cycle1_page1 = vec![0x01; PAGE_SIZE_USIZE]; + shared_mem + .copy_from_slice(&cycle1_page0, safe_offset1) + .unwrap(); + shared_mem + .copy_from_slice(&cycle1_page1, safe_offset2) + .unwrap(); + + // Stop tracking and get dirty pages bitmap + let dirty_pages_vec1 = shared_mem.stop_tracking_dirty_pages(tracker1).unwrap(); + + // Convert to bitmap format + let mut dirty_pages = new_page_bitmap(shared_mem.mem_size(), false).unwrap(); + for page_idx in dirty_pages_vec1 { + let block = page_idx / 64; + let bit = page_idx % 64; + if block < dirty_pages.len() { + dirty_pages[block] |= 1 << bit; + } + } + + let mut snapshot_manager = + super::SharedMemorySnapshotManager::new(&mut shared_mem, &dirty_pages, &layout, 0) + .unwrap(); + + // Cycle 2: Modify and snapshot + let tracker2 = shared_mem.start_tracking_dirty_pages().unwrap(); + let cycle2_page0 = vec![0x02; PAGE_SIZE_USIZE]; + let cycle2_page1 = (0..PAGE_SIZE_USIZE) + .map(|i| ((i + 100) % 256) as u8) + .collect::>(); + shared_mem + .copy_from_slice(&cycle2_page0, safe_offset1) + .unwrap(); + shared_mem + .copy_from_slice(&cycle2_page1, safe_offset2) + .unwrap(); + + let dirty_pages_vec2 = shared_mem.stop_tracking_dirty_pages(tracker2).unwrap(); + + let mut dirty_pages2 = new_page_bitmap(shared_mem.mem_size(), false).unwrap(); + for page_idx in dirty_pages_vec2 { + let block = page_idx / 64; + let bit = page_idx % 64; + if block < dirty_pages2.len() { + dirty_pages2[block] |= 1 << bit; + } + } + + 
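The block/bit arithmetic repeated here and in the other tests (block = page / 64, bit = page % 64) is the 64-pages-per-u64 layout implied by PAGES_IN_BLOCK. A minimal, self-contained sketch of how that conversion could be factored out, using a hypothetical helper name:

```rust
/// Illustrative helper (hypothetical name): fold a list of dirty page indices into the
/// per-block u64 bitmap layout used elsewhere in this patch (64 pages per block).
fn indices_to_bitmap(dirty_page_indices: &[usize], num_pages: usize) -> Vec<u64> {
    const PAGES_IN_BLOCK: usize = 64;
    let mut bitmap = vec![0u64; num_pages.div_ceil(PAGES_IN_BLOCK)];
    for &page_idx in dirty_page_indices {
        bitmap[page_idx / PAGES_IN_BLOCK] |= 1u64 << (page_idx % PAGES_IN_BLOCK);
    }
    bitmap
}
```

With num_pages taken as mem_size / PAGE_SIZE, this mirrors what the tests do by hand after calling new_page_bitmap.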
snapshot_manager + .create_new_snapshot(&mut shared_mem, &dirty_pages2, 0) + .unwrap(); + + // Cycle 3: Modify again + let cycle3_page0 = vec![0x03; PAGE_SIZE_USIZE]; + let cycle3_page1 = vec![0x33; PAGE_SIZE_USIZE]; + shared_mem + .copy_from_slice(&cycle3_page0, safe_offset1) + .unwrap(); + shared_mem + .copy_from_slice(&cycle3_page1, safe_offset2) + .unwrap(); + + // Verify current state (cycle 3) + let mut current_page0 = vec![0u8; PAGE_SIZE_USIZE]; + shared_mem + .copy_to_slice(&mut current_page0, safe_offset1) + .unwrap(); + let mut current_page1 = vec![0u8; PAGE_SIZE_USIZE]; + shared_mem + .copy_to_slice(&mut current_page1, safe_offset2) + .unwrap(); + assert_eq!(current_page0, cycle3_page0); + assert_eq!(current_page1, cycle3_page1); + + // Restore to cycle 2 + snapshot_manager + .restore_from_snapshot(&mut shared_mem, &dirty_pages2) + .unwrap(); + let mut restored_page0 = vec![0u8; PAGE_SIZE_USIZE]; + shared_mem + .copy_to_slice(&mut restored_page0, safe_offset1) + .unwrap(); + let mut restored_page1 = vec![0u8; PAGE_SIZE_USIZE]; + shared_mem + .copy_to_slice(&mut restored_page1, safe_offset2) + .unwrap(); + assert_eq!(restored_page0, cycle2_page0); + assert_eq!(restored_page1, cycle2_page1); + + // Pop cycle 2 and restore to cycle 1 + snapshot_manager + .pop_and_restore_state_from_snapshot(&mut shared_mem, &dirty_pages) + .unwrap(); + let mut restored_page0 = vec![0u8; PAGE_SIZE_USIZE]; + shared_mem + .copy_to_slice(&mut restored_page0, safe_offset1) + .unwrap(); + let mut restored_page1 = vec![0u8; PAGE_SIZE_USIZE]; + shared_mem + .copy_to_slice(&mut restored_page1, safe_offset2) + .unwrap(); + assert_eq!(restored_page0, cycle1_page0); + assert_eq!(restored_page1, cycle1_page1); + } + + #[test] + fn test_restore_with_zero_pages() { + let layout = create_test_layout(); + let mut shared_mem = create_test_shared_memory_with_layout(&layout); + + // Get safe init_data area for testing + let (init_data_offset, init_data_size) = get_safe_test_area(&layout, &mut shared_mem); + + // Ensure we have enough space for testing + assert!( + init_data_size >= 3 * PAGE_SIZE_USIZE, + "Init data area too small for testing" + ); + + // Start tracking dirty pages + let tracker = shared_mem.start_tracking_dirty_pages().unwrap(); + + // Only initialize one page in the init_data area + let page1_offset = init_data_offset + PAGE_SIZE_USIZE; // Skip first page for safety + let page1_data = vec![0xFF; PAGE_SIZE_USIZE]; + shared_mem + .copy_from_slice(&page1_data, page1_offset) + .unwrap(); + + // Stop tracking and get dirty pages bitmap + let dirty_pages_vec = shared_mem.stop_tracking_dirty_pages(tracker).unwrap(); + + // Convert to bitmap format + let mut dirty_pages_snapshot = new_page_bitmap(shared_mem.mem_size(), false).unwrap(); + for page_idx in dirty_pages_vec { + let block = page_idx / 64; + let bit = page_idx % 64; + if block < dirty_pages_snapshot.len() { + dirty_pages_snapshot[block] |= 1 << bit; + } + } + + let mut snapshot_manager = super::SharedMemorySnapshotManager::new( + &mut shared_mem, + &dirty_pages_snapshot, + &layout, + 0, + ) + .unwrap(); + + // Modify pages in init_data area + let page0_offset = init_data_offset; + let page2_offset = init_data_offset + 2 * PAGE_SIZE_USIZE; + + let modified_page0 = vec![0xAA; PAGE_SIZE_USIZE]; + let modified_page1 = vec![0xBB; PAGE_SIZE_USIZE]; + let modified_page2 = vec![0xCC; PAGE_SIZE_USIZE]; + shared_mem + .copy_from_slice(&modified_page0, page0_offset) + .unwrap(); + shared_mem + .copy_from_slice(&modified_page1, page1_offset) + .unwrap(); + 
shared_mem + .copy_from_slice(&modified_page2, page2_offset) + .unwrap(); + + // Create dirty page map for all test pages + let mut dirty_pages_restore = new_page_bitmap(shared_mem.mem_size(), false).unwrap(); + let page0_idx = page0_offset / PAGE_SIZE_USIZE; + let page1_idx = page1_offset / PAGE_SIZE_USIZE; + let page2_idx = page2_offset / PAGE_SIZE_USIZE; + + // Mark all test pages as dirty for restore + for &page_idx in &[page0_idx, page1_idx, page2_idx] { + let block = page_idx / 64; + let bit = page_idx % 64; + if block < dirty_pages_restore.len() { + dirty_pages_restore[block] |= 1 << bit; + } + } + + // Restore from snapshot + snapshot_manager + .restore_from_snapshot(&mut shared_mem, &dirty_pages_restore) + .unwrap(); + + // Verify restoration + let mut restored_page0 = vec![0u8; PAGE_SIZE_USIZE]; + shared_mem + .copy_to_slice(&mut restored_page0, page0_offset) + .unwrap(); + let mut restored_page1 = vec![0u8; PAGE_SIZE_USIZE]; + shared_mem + .copy_to_slice(&mut restored_page1, page1_offset) + .unwrap(); + let mut restored_page2 = vec![0u8; PAGE_SIZE_USIZE]; + shared_mem + .copy_to_slice(&mut restored_page2, page2_offset) + .unwrap(); + + // Page 0 and 2 should be zeroed (not in snapshot), page 1 should be restored + assert_eq!(restored_page0, vec![0u8; PAGE_SIZE_USIZE]); + assert_eq!(restored_page1, page1_data); + assert_eq!(restored_page2, vec![0u8; PAGE_SIZE_USIZE]); + } + + #[test] + fn test_empty_snapshot_error() { + let layout = create_test_layout(); + let mut shared_mem = create_test_shared_memory_with_layout(&layout); + let memory_size = shared_mem.mem_size(); + + // Create snapshot manager with no snapshots + let mut snapshot_manager = super::SharedMemorySnapshotManager { + snapshots: vec![], + input_data_size: 0, + output_data_size: 0, + output_data_buffer_offset: 0, + input_data_buffer_offset: 0, + }; + + let dirty_pages = new_page_bitmap(memory_size, true).unwrap(); + + // Should return error when trying to restore from empty snapshots + let result = snapshot_manager.restore_from_snapshot(&mut shared_mem, &dirty_pages); + assert!(result.is_err()); + assert!(matches!( + result.unwrap_err(), + crate::HyperlightError::NoMemorySnapshot + )); + + // Should return error when trying to pop from empty snapshots + let result = + snapshot_manager.pop_and_restore_state_from_snapshot(&mut shared_mem, &dirty_pages); + assert!(result.is_err()); + assert!(matches!( + result.unwrap_err(), + crate::HyperlightError::NoMemorySnapshot + )); + } + + #[test] + fn test_complex_workflow_simulation() { + let layout = create_test_layout(); + let mut shared_mem = create_test_shared_memory_with_layout(&layout); + + // Get safe init_data area for testing + let (init_data_offset, init_data_size) = get_safe_test_area(&layout, &mut shared_mem); + + // Ensure we have enough space for 4 test pages + assert!( + init_data_size >= 6 * PAGE_SIZE_USIZE, + "Init data area too small for testing" + ); + + // Start tracking dirty pages + let tracker = shared_mem.start_tracking_dirty_pages().unwrap(); + + // Use the init_data area - this is safe and won't interfere with other layout structures + let base_page = (init_data_offset + PAGE_SIZE_USIZE) / PAGE_SIZE_USIZE; // Skip first page for safety + let page_offsets = [base_page, base_page + 1, base_page + 2, base_page + 3]; + + // Initialize memory with pattern in init_data area + for (i, &page_offset) in page_offsets.iter().enumerate() { + let data = vec![i as u8; PAGE_SIZE_USIZE]; + let offset = page_offset * PAGE_SIZE_USIZE; + assert!( + offset + 
PAGE_SIZE_USIZE <= shared_mem.mem_size(), + "Page offset {} exceeds memory bounds", + page_offset + ); + assert!( + offset >= init_data_offset + && offset + PAGE_SIZE_USIZE <= init_data_offset + init_data_size, + "Page offset {} is outside init_data bounds", + page_offset + ); + shared_mem.copy_from_slice(&data, offset).unwrap(); + } + + // Stop tracking and get dirty pages bitmap + let dirty_pages_vec = shared_mem.stop_tracking_dirty_pages(tracker).unwrap(); + + // Convert to bitmap format + let mut dirty_pages = new_page_bitmap(shared_mem.mem_size(), false).unwrap(); + for page_idx in dirty_pages_vec { + let block = page_idx / 64; + let bit = page_idx % 64; + if block < dirty_pages.len() { + dirty_pages[block] |= 1 << bit; + } + } + + // Create initial checkpoint + let mut snapshot_manager = + super::SharedMemorySnapshotManager::new(&mut shared_mem, &dirty_pages, &layout, 0) + .unwrap(); + + // Simulate function call 1: modify pages 0 and 2 + let tracker1 = shared_mem.start_tracking_dirty_pages().unwrap(); + let func1_page0 = vec![0x10; PAGE_SIZE_USIZE]; + let func1_page2 = vec![0x12; PAGE_SIZE_USIZE]; + shared_mem + .copy_from_slice(&func1_page0, page_offsets[0] * PAGE_SIZE_USIZE) + .unwrap(); + shared_mem + .copy_from_slice(&func1_page2, page_offsets[2] * PAGE_SIZE_USIZE) + .unwrap(); + + let dirty_pages_vec1 = shared_mem.stop_tracking_dirty_pages(tracker1).unwrap(); + + let mut dirty_pages1 = new_page_bitmap(shared_mem.mem_size(), false).unwrap(); + for page_idx in dirty_pages_vec1 { + let block = page_idx / 64; + let bit = page_idx % 64; + if block < dirty_pages1.len() { + dirty_pages1[block] |= 1 << bit; + } + } + + // Checkpoint after function 1 + snapshot_manager + .create_new_snapshot(&mut shared_mem, &dirty_pages1, 0) + .unwrap(); + + // Simulate function call 2: modify pages 1 and 3 + let tracker2 = shared_mem.start_tracking_dirty_pages().unwrap(); + let func2_page1 = vec![0x21; PAGE_SIZE_USIZE]; + let func2_page3 = vec![0x23; PAGE_SIZE_USIZE]; + shared_mem + .copy_from_slice(&func2_page1, page_offsets[1] * PAGE_SIZE_USIZE) + .unwrap(); + shared_mem + .copy_from_slice(&func2_page3, page_offsets[3] * PAGE_SIZE_USIZE) + .unwrap(); + + let dirty_pages_vec2 = shared_mem.stop_tracking_dirty_pages(tracker2).unwrap(); + + let mut dirty_pages2 = new_page_bitmap(shared_mem.mem_size(), false).unwrap(); + for page_idx in dirty_pages_vec2 { + let block = page_idx / 64; + let bit = page_idx % 64; + if block < dirty_pages2.len() { + dirty_pages2[block] |= 1 << bit; + } + } + + // Checkpoint after function 2 + snapshot_manager + .create_new_snapshot(&mut shared_mem, &dirty_pages2, 0) + .unwrap(); + + // Simulate function call 3: modify all pages + for (i, &page_offset) in page_offsets.iter().enumerate() { + let data = vec![0x30 + i as u8; PAGE_SIZE_USIZE]; + shared_mem + .copy_from_slice(&data, page_offset * PAGE_SIZE_USIZE) + .unwrap(); + } + + // Verify current state (after function 3) + for (i, &page_offset) in page_offsets.iter().enumerate() { + let mut current = vec![0u8; PAGE_SIZE_USIZE]; + shared_mem + .copy_to_slice(&mut current, page_offset * PAGE_SIZE_USIZE) + .unwrap(); + let expected = vec![0x30 + i as u8; PAGE_SIZE_USIZE]; + assert_eq!(current, expected); + } + + // Create a bitmap that includes all pages that were modified in function 3 + let mut dirty_pages_all_func3 = new_page_bitmap(shared_mem.mem_size(), false).unwrap(); + for &page_offset in &page_offsets { + let block = page_offset / 64; + let bit = page_offset % 64; + if block < dirty_pages_all_func3.len() { + 
dirty_pages_all_func3[block] |= 1 << bit; + } + } + + // Rollback to after function 2 + snapshot_manager + .restore_from_snapshot(&mut shared_mem, &dirty_pages_all_func3) + .unwrap(); + + // Verify state after function 2 + let mut page0_after_func2 = vec![0u8; PAGE_SIZE_USIZE]; + shared_mem + .copy_to_slice(&mut page0_after_func2, page_offsets[0] * PAGE_SIZE_USIZE) + .unwrap(); + assert_eq!(page0_after_func2, func1_page0); + + let mut page1_after_func2 = vec![0u8; PAGE_SIZE_USIZE]; + shared_mem + .copy_to_slice(&mut page1_after_func2, page_offsets[1] * PAGE_SIZE_USIZE) + .unwrap(); + assert_eq!(page1_after_func2, func2_page1); + + let mut page2_after_func2 = vec![0u8; PAGE_SIZE_USIZE]; + shared_mem + .copy_to_slice(&mut page2_after_func2, page_offsets[2] * PAGE_SIZE_USIZE) + .unwrap(); + assert_eq!(page2_after_func2, func1_page2); + + let mut page3_after_func2 = vec![0u8; PAGE_SIZE_USIZE]; + shared_mem + .copy_to_slice(&mut page3_after_func2, page_offsets[3] * PAGE_SIZE_USIZE) + .unwrap(); + assert_eq!(page3_after_func2, func2_page3); + + // Rollback to after function 1 + // Need to create a bitmap that includes all pages that could have been modified + let mut combined_dirty_pages1 = dirty_pages.clone(); + for i in 0..combined_dirty_pages1.len().min(dirty_pages1.len()) { + combined_dirty_pages1[i] |= dirty_pages1[i]; + } + for i in 0..combined_dirty_pages1.len().min(dirty_pages2.len()) { + combined_dirty_pages1[i] |= dirty_pages2[i]; + } + + snapshot_manager + .pop_and_restore_state_from_snapshot(&mut shared_mem, &combined_dirty_pages1) + .unwrap(); + + // Verify state after function 1 + let mut page0_after_func1 = vec![0u8; PAGE_SIZE_USIZE]; + shared_mem + .copy_to_slice(&mut page0_after_func1, page_offsets[0] * PAGE_SIZE_USIZE) + .unwrap(); + assert_eq!(page0_after_func1, func1_page0); + + let mut page1_after_func1 = vec![0u8; PAGE_SIZE_USIZE]; + shared_mem + .copy_to_slice(&mut page1_after_func1, page_offsets[1] * PAGE_SIZE_USIZE) + .unwrap(); + assert_eq!(page1_after_func1, vec![1u8; PAGE_SIZE_USIZE]); // Original + + let mut page2_after_func1 = vec![0u8; PAGE_SIZE_USIZE]; + shared_mem + .copy_to_slice(&mut page2_after_func1, page_offsets[2] * PAGE_SIZE_USIZE) + .unwrap(); + assert_eq!(page2_after_func1, func1_page2); + + let mut page3_after_func1 = vec![0u8; PAGE_SIZE_USIZE]; + shared_mem + .copy_to_slice(&mut page3_after_func1, page_offsets[3] * PAGE_SIZE_USIZE) + .unwrap(); + assert_eq!(page3_after_func1, vec![3u8; PAGE_SIZE_USIZE]); // Original + + // Rollback to initial state + // Need to create a bitmap that includes all pages that could have been modified + let mut combined_dirty_pages_all = dirty_pages.clone(); + for i in 0..combined_dirty_pages_all.len().min(dirty_pages1.len()) { + combined_dirty_pages_all[i] |= dirty_pages1[i]; + } + for i in 0..combined_dirty_pages_all.len().min(dirty_pages2.len()) { + combined_dirty_pages_all[i] |= dirty_pages2[i]; + } + + snapshot_manager + .pop_and_restore_state_from_snapshot(&mut shared_mem, &combined_dirty_pages_all) + .unwrap(); + + // Verify initial state + for (i, &page_offset) in page_offsets.iter().enumerate() { + let mut current = vec![0u8; PAGE_SIZE_USIZE]; + shared_mem + .copy_to_slice(&mut current, page_offset * PAGE_SIZE_USIZE) + .unwrap(); + let expected = vec![i as u8; PAGE_SIZE_USIZE]; + assert_eq!(current, expected); + } + } + + #[test] + fn test_unchanged_data_verification() { + let layout = create_test_layout(); + let mut shared_mem = create_test_shared_memory_with_layout(&layout); + + // Get safe init_data area for 
testing + let (init_data_offset, init_data_size) = get_safe_test_area(&layout, &mut shared_mem); + + // Ensure we have enough space for 6 test pages + assert!( + init_data_size >= 8 * PAGE_SIZE_USIZE, + "Init data area too small for testing" + ); + + // Start tracking dirty pages + let tracker = shared_mem.start_tracking_dirty_pages().unwrap(); + + // Initialize all pages with different patterns - use safe offsets within init_data area + let base_page = (init_data_offset + PAGE_SIZE_USIZE) / PAGE_SIZE_USIZE; // Skip first page for safety + let page_offsets = [ + base_page, + base_page + 1, + base_page + 2, + base_page + 3, + base_page + 4, + base_page + 5, + ]; + let initial_patterns = [ + vec![0xAA; PAGE_SIZE_USIZE], // Page 0 + vec![0xBB; PAGE_SIZE_USIZE], // Page 1 + vec![0xCC; PAGE_SIZE_USIZE], // Page 2 + vec![0xDD; PAGE_SIZE_USIZE], // Page 3 + vec![0xEE; PAGE_SIZE_USIZE], // Page 4 + vec![0xFF; PAGE_SIZE_USIZE], // Page 5 + ]; + + for (i, pattern) in initial_patterns.iter().enumerate() { + let offset = page_offsets[i] * PAGE_SIZE_USIZE; + assert!( + offset + PAGE_SIZE_USIZE <= shared_mem.mem_size(), + "Page offset {} exceeds memory bounds", + page_offsets[i] + ); + assert!( + offset >= init_data_offset + && offset + PAGE_SIZE_USIZE <= init_data_offset + init_data_size, + "Page offset {} is outside init_data bounds", + page_offsets[i] + ); + shared_mem.copy_from_slice(pattern, offset).unwrap(); + } + + // Stop tracking and get dirty pages bitmap + let dirty_pages_vec = shared_mem.stop_tracking_dirty_pages(tracker).unwrap(); + + // Convert to bitmap format - only track specific pages (1, 3, 5) + let mut dirty_pages = new_page_bitmap(shared_mem.mem_size(), false).unwrap(); + let tracked_pages = [1, 3, 5]; // Only track these pages for snapshot + for &tracked_page_idx in &tracked_pages { + let actual_page = page_offsets[tracked_page_idx]; + if dirty_pages_vec.contains(&actual_page) { + let block = actual_page / 64; + let bit = actual_page % 64; + if block < dirty_pages.len() { + dirty_pages[block] |= 1 << bit; + } + } + } + + // Create snapshot + let mut snapshot_manager = + super::SharedMemorySnapshotManager::new(&mut shared_mem, &dirty_pages, &layout, 0) + .unwrap(); + + // Modify only the dirty pages + let modified_patterns = [ + vec![0x11; PAGE_SIZE_USIZE], // Page 1 modified + vec![0x33; PAGE_SIZE_USIZE], // Page 3 modified + vec![0x55; PAGE_SIZE_USIZE], // Page 5 modified + ]; + + shared_mem + .copy_from_slice(&modified_patterns[0], page_offsets[1] * PAGE_SIZE_USIZE) + .unwrap(); + shared_mem + .copy_from_slice(&modified_patterns[1], page_offsets[3] * PAGE_SIZE_USIZE) + .unwrap(); + shared_mem + .copy_from_slice(&modified_patterns[2], page_offsets[5] * PAGE_SIZE_USIZE) + .unwrap(); + + // Verify that untracked pages (0, 2, 4) remain unchanged + let unchanged_pages = [0, 2, 4]; + for &page_idx in &unchanged_pages { + let mut current_data = vec![0u8; PAGE_SIZE_USIZE]; + shared_mem + .copy_to_slice(&mut current_data, page_offsets[page_idx] * PAGE_SIZE_USIZE) + .unwrap(); + assert_eq!( + current_data, initial_patterns[page_idx], + "Page {} should remain unchanged after modification", + page_idx + ); + } + + // Verify that tracked pages were modified + let changed_pages = [ + (1, &modified_patterns[0]), + (3, &modified_patterns[1]), + (5, &modified_patterns[2]), + ]; + for &(page_idx, expected) in &changed_pages { + let mut current_data = vec![0u8; PAGE_SIZE_USIZE]; + shared_mem + .copy_to_slice(&mut current_data, page_offsets[page_idx] * PAGE_SIZE_USIZE) + .unwrap(); + assert_eq!( + 
current_data, *expected, + "Page {} should be modified", + page_idx + ); + } + + // Restore from snapshot + snapshot_manager + .restore_from_snapshot(&mut shared_mem, &dirty_pages) + .unwrap(); + + // Verify tracked pages are restored to their original state + for &page_idx in &tracked_pages { + let mut restored_data = vec![0u8; PAGE_SIZE_USIZE]; + shared_mem + .copy_to_slice(&mut restored_data, page_offsets[page_idx] * PAGE_SIZE_USIZE) + .unwrap(); + assert_eq!( + restored_data, initial_patterns[page_idx], + "Page {} should be restored to initial pattern after snapshot restore", + page_idx + ); + } + + // Test partial dirty bitmap restoration + let mut partial_dirty = new_page_bitmap(shared_mem.mem_size(), false).unwrap(); + // Only mark page 1 as dirty for restoration + let page1_actual = page_offsets[1]; + let block = page1_actual / 64; + let bit = page1_actual % 64; + if block < partial_dirty.len() { + partial_dirty[block] |= 1 << bit; + } + + // Modify multiple pages again + shared_mem + .copy_from_slice(&modified_patterns[0], page_offsets[1] * PAGE_SIZE_USIZE) + .unwrap(); + shared_mem + .copy_from_slice(&modified_patterns[1], page_offsets[3] * PAGE_SIZE_USIZE) + .unwrap(); + + // Restore with partial dirty bitmap (only page 1) + snapshot_manager + .restore_from_snapshot(&mut shared_mem, &partial_dirty) + .unwrap(); + + // Verify page 1 is restored but page 3 remains modified + let mut page1_data = vec![0u8; PAGE_SIZE_USIZE]; + shared_mem + .copy_to_slice(&mut page1_data, page_offsets[1] * PAGE_SIZE_USIZE) + .unwrap(); + assert_eq!(page1_data, initial_patterns[1], "Page 1 should be restored"); + + let mut page3_data = vec![0u8; PAGE_SIZE_USIZE]; + shared_mem + .copy_to_slice(&mut page3_data, page_offsets[3] * PAGE_SIZE_USIZE) + .unwrap(); + assert_eq!( + page3_data, modified_patterns[1], + "Page 3 should remain modified since it wasn't in restoration dirty bitmap" + ); + + // Verify all other pages remain in their expected state + for page_idx in [0, 2, 4, 5] { + let mut current_data = vec![0u8; PAGE_SIZE_USIZE]; + shared_mem + .copy_to_slice(&mut current_data, page_offsets[page_idx] * PAGE_SIZE_USIZE) + .unwrap(); + assert_eq!( + current_data, initial_patterns[page_idx], + "Page {} should remain in initial state", + page_idx + ); + } + } +} diff --git a/src/hyperlight_host/src/mem/windows_dirty_page_tracker.rs b/src/hyperlight_host/src/mem/windows_dirty_page_tracker.rs new file mode 100644 index 000000000..da0463262 --- /dev/null +++ b/src/hyperlight_host/src/mem/windows_dirty_page_tracker.rs @@ -0,0 +1,67 @@ +/* +Copyright 2025 The Hyperlight Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/
+
+use std::sync::Arc;
+
+use hyperlight_common::mem::PAGE_SIZE_USIZE;
+use tracing::{Span, instrument};
+
+use super::dirty_page_tracking::DirtyPageTracking;
+use super::shared_mem::{HostMapping, SharedMemory};
+use crate::Result;
+
+/// Windows implementation of dirty page tracking
+pub struct WindowsDirtyPageTracker {
+    _base_addr: usize,
+    _size: usize,
+    num_pages: usize,
+    /// Keep a reference to the HostMapping to ensure memory lifetime
+    _mapping: Arc<HostMapping>,
+}
+
+// DirtyPageTracker should be Send because:
+// 1. The Arc ensures the memory stays valid
+// 2. The tracker handles synchronization properly
+// 3. This is needed for threaded sandbox initialization
+unsafe impl Send for WindowsDirtyPageTracker {}
+
+impl WindowsDirtyPageTracker {
+    /// Create a new Windows dirty page tracker
+    #[instrument(skip_all, parent = Span::current(), level = "Trace")]
+    pub fn new<T: SharedMemory>(shared_memory: &T) -> Result<Self> {
+        let mapping = shared_memory.region_arc();
+        let base_addr = shared_memory.base_addr();
+        let size = shared_memory.raw_mem_size();
+        let num_pages = size.div_ceil(PAGE_SIZE_USIZE);
+
+        Ok(Self {
+            _base_addr: base_addr,
+            _size: size,
+            num_pages,
+            _mapping: mapping,
+        })
+    }
+}
+
+impl DirtyPageTracking for WindowsDirtyPageTracker {
+    /// Returns the indices of all pages in the tracked region.
+    /// This is a simplified implementation that marks all pages as dirty
+    /// until we implement actual dirty page tracking.
+    fn get_dirty_pages(self) -> Result<Vec<usize>> {
+        // Return all page indices from 0 to num_pages-1
+        Ok((0..self.num_pages).collect()) // exclude the guard pages
+    }
+}
diff --git a/src/hyperlight_host/src/sandbox/initialized_multi_use.rs b/src/hyperlight_host/src/sandbox/initialized_multi_use.rs
index 8df9d08ef..fc5c253c6 100644
--- a/src/hyperlight_host/src/sandbox/initialized_multi_use.rs
+++ b/src/hyperlight_host/src/sandbox/initialized_multi_use.rs
@@ -270,8 +270,9 @@ impl MultiUseSandbox {
     #[instrument(err(Debug), skip_all, parent = Span::current(), level = "Trace")]
     pub(crate) fn restore_state(&mut self) -> Result<()> {
         let mem_mgr = self.mem_mgr.unwrap_mgr_mut();
-        let rgns_to_unmap = mem_mgr.restore_state_from_last_snapshot()?;
-        unsafe { self.vm.unmap_regions(rgns_to_unmap)? };
+        let dirty_pages = self.vm.get_and_clear_dirty_pages()?;
+        let rgns_to_unmap = mem_mgr.restore_state_from_last_snapshot(&dirty_pages)?;
+        unsafe { self.vm.unmap_regions(rgns_to_unmap)? };
         Ok(())
     }
@@ -354,10 +355,11 @@ impl DevolvableSandbox
     ) -> Result<MultiUseSandbox> {
+        let dirty_pages = self.vm.get_and_clear_dirty_pages()?;
         let rgns_to_unmap = self
             .mem_mgr
             .unwrap_mgr_mut()
-            .pop_and_restore_state_from_snapshot()?;
+            .pop_and_restore_state_from_snapshot(&dirty_pages)?;
         unsafe { self.vm.unmap_regions(rgns_to_unmap)? };
         Ok(self)
     }
 }
@@ -389,7 +391,8 @@ where
         let mut ctx = self.new_call_context();
         transition_func.call(&mut ctx)?;
         let mut sbox = ctx.finish_no_reset();
-        sbox.mem_mgr.unwrap_mgr_mut().push_state()?;
+        let vm_dirty_pages = sbox.vm.get_and_clear_dirty_pages()?;
+        sbox.mem_mgr.unwrap_mgr_mut().push_state(&vm_dirty_pages)?;
         Ok(sbox)
     }
 }
diff --git a/src/hyperlight_host/src/sandbox/outb.rs b/src/hyperlight_host/src/sandbox/outb.rs
index dcdd96589..43eb1adcd 100644
--- a/src/hyperlight_host/src/sandbox/outb.rs
+++ b/src/hyperlight_host/src/sandbox/outb.rs
@@ -241,7 +241,7 @@ mod tests {
         let new_mgr = || {
             let mut exe_info = simple_guest_exe_info().unwrap();
-            let mut mgr = SandboxMemoryManager::load_guest_binary_into_memory(
+            let (mut mgr, _) = SandboxMemoryManager::load_guest_binary_into_memory(
                 sandbox_cfg,
                 &mut exe_info,
                 None,
@@ -356,7 +356,7 @@ mod tests {
         tracing::subscriber::with_default(subscriber.clone(), || {
             let new_mgr = || {
                 let mut exe_info = simple_guest_exe_info().unwrap();
-                let mut mgr = SandboxMemoryManager::load_guest_binary_into_memory(
+                let (mut mgr, _) = SandboxMemoryManager::load_guest_binary_into_memory(
                     sandbox_cfg,
                     &mut exe_info,
                     None,
diff --git a/src/hyperlight_host/src/sandbox/uninitialized.rs b/src/hyperlight_host/src/sandbox/uninitialized.rs
index e27f91ff2..ae17f1081 100644
--- a/src/hyperlight_host/src/sandbox/uninitialized.rs
+++ b/src/hyperlight_host/src/sandbox/uninitialized.rs
@@ -29,6 +29,7 @@ use crate::func::host_functions::{HostFunction, register_host_function};
 use crate::func::{ParameterTuple, SupportedReturnType};
 #[cfg(feature = "build-metadata")]
 use crate::log_build_details;
+use crate::mem::dirty_page_tracking::DirtyPageTracker;
 use crate::mem::exe::ExeInfo;
 use crate::mem::memory_region::{DEFAULT_GUEST_BLOB_MEM_FLAGS, MemoryRegionFlags};
 use crate::mem::mgr::{STACK_COOKIE_LEN, SandboxMemoryManager};
@@ -80,6 +81,7 @@ pub struct UninitializedSandbox {
     pub(crate) config: SandboxConfiguration,
     #[cfg(any(crashdump, gdb))]
     pub(crate) rt_cfg: SandboxRuntimeConfig,
+    pub(crate) tracker: Option<DirtyPageTracker>,
 }
 
 impl crate::sandbox_state::sandbox::UninitializedSandbox for UninitializedSandbox {
@@ -250,17 +252,15 @@ impl UninitializedSandbox {
             }
         };
 
-        let mut mem_mgr_wrapper = {
-            let mut mgr = UninitializedSandbox::load_guest_binary(
-                sandbox_cfg,
-                &guest_binary,
-                guest_blob.as_ref(),
-            )?;
+        let (mut mgr, tracker) = UninitializedSandbox::load_guest_binary(
+            sandbox_cfg,
+            &guest_binary,
+            guest_blob.as_ref(),
+        )?;
 
-            let stack_guard = Self::create_stack_guard();
-            mgr.set_stack_guard(&stack_guard)?;
-            MemMgrWrapper::new(mgr, stack_guard)
-        };
+        let stack_guard = Self::create_stack_guard();
+        mgr.set_stack_guard(&stack_guard)?;
+        let mut mem_mgr_wrapper = MemMgrWrapper::new(mgr, stack_guard);
 
         mem_mgr_wrapper.write_memory_layout()?;
@@ -278,6 +278,7 @@ impl UninitializedSandbox {
             config: sandbox_cfg,
             #[cfg(any(crashdump, gdb))]
             rt_cfg,
+            tracker: Some(tracker),
         };
 
         // If we were passed a writer for host print register it otherwise use the default.
@@ -308,7 +309,10 @@ impl UninitializedSandbox {
         cfg: SandboxConfiguration,
         guest_binary: &GuestBinary,
         guest_blob: Option<&GuestBlob>,
-    ) -> Result<SandboxMemoryManager<ExclusiveSharedMemory>> {
+    ) -> Result<(
+        SandboxMemoryManager<ExclusiveSharedMemory>,
+        DirtyPageTracker,
+    )> {
         let mut exe_info = match guest_binary {
             GuestBinary::FilePath(bin_path_str) => ExeInfo::from_file(bin_path_str)?,
             GuestBinary::Buffer(buffer) => ExeInfo::from_buf(buffer)?,
@@ -396,6 +400,7 @@ impl UninitializedSandbox {
         Ok(())
     }
 }
+
 // Check to see if the current version of Windows is supported
 // Hyperlight is only supported on Windows 11 and Windows Server 2022 and later
 #[cfg(target_os = "windows")]
diff --git a/src/hyperlight_host/src/sandbox/uninitialized_evolve.rs b/src/hyperlight_host/src/sandbox/uninitialized_evolve.rs
index a37f747e2..9ce36237b 100644
--- a/src/hyperlight_host/src/sandbox/uninitialized_evolve.rs
+++ b/src/hyperlight_host/src/sandbox/uninitialized_evolve.rs
@@ -59,7 +59,7 @@ use crate::{MultiUseSandbox, Result, UninitializedSandbox, log_then_return, new_
 /// please reach out to a Hyperlight developer before making the change.
 #[instrument(err(Debug), skip_all, parent = Span::current(), level = "Trace")]
 fn evolve_impl(
-    u_sbox: UninitializedSandbox,
+    mut u_sbox: UninitializedSandbox,
     transform: TransformFunc,
 ) -> Result
 where
@@ -70,15 +70,20 @@ where
         Arc>,
         Arc>,
         RawPtr,
+        &[usize], // dirty host pages (indices, not bitmap)
     ) -> Result,
 {
     let (hshm, mut gshm) = u_sbox.mgr.build();
-    let mut vm = set_up_hypervisor_partition(
-        &mut gshm,
-        &u_sbox.config,
-        #[cfg(any(crashdump, gdb))]
-        &u_sbox.rt_cfg,
-    )?;
+
+    let tracker = match u_sbox.tracker.take() {
+        Some(tracker) => tracker,
+        None => {
+            return Err(new_error!(
+                "Failed to take tracker from UninitializedSandbox"
+            ));
+        }
+    };
+
     let outb_hdl = outb_handler_wrapper(hshm.clone(), u_sbox.host_funcs.clone());
 
     let seed = {
@@ -99,6 +104,18 @@ where
     #[cfg(target_os = "linux")]
     setup_signal_handlers(&u_sbox.config)?;
 
+    // before entering VM (and before mapping memory into VM), stop tracking dirty pages from the host side
+    let dirty_host_pages_idx = gshm
+        .get_shared_mem_mut()
+        .with_exclusivity(|e| e.stop_tracking_dirty_pages(tracker))??;
+
+    let mut vm = set_up_hypervisor_partition(
+        &mut gshm,
+        &u_sbox.config,
+        #[cfg(any(crashdump, gdb))]
+        &u_sbox.rt_cfg,
+    )?;
+
     vm.initialise(
         peb_addr,
         seed,
@@ -122,6 +139,7 @@ where
         outb_hdl,
         mem_access_hdl,
         RawPtr::from(dispatch_function_addr),
+        &dirty_host_pages_idx,
     )
 }
 
@@ -129,9 +147,15 @@ where
 pub(super) fn evolve_impl_multi_use(u_sbox: UninitializedSandbox) -> Result<MultiUseSandbox> {
     evolve_impl(
         u_sbox,
-        |hf, mut hshm, vm, out_hdl, mem_hdl, dispatch_ptr| {
+        |hf, mut hshm, mut vm, out_hdl, mem_hdl, dispatch_ptr, host_dirty_pages_idx| {
             {
-                hshm.as_mut().push_state()?;
+                let vm_dirty_pages_bitmap = vm.get_and_clear_dirty_pages()?;
+                let layout = hshm.unwrap_mgr().layout;
+                hshm.as_mut().create_initial_snapshot(
+                    &vm_dirty_pages_bitmap,
+                    host_dirty_pages_idx,
+                    &layout,
+                )?;
             }
             Ok(MultiUseSandbox::from_uninit(
                 hf,
@@ -163,6 +187,7 @@ pub(crate) fn set_up_hypervisor_partition(
     #[cfg(not(feature = "init-paging"))]
     let rsp_ptr = GuestPtr::try_from(Offset::from(0))?;
     let regions = mgr.layout.get_memory_regions(&mgr.shared_mem)?;
+    let base_ptr = GuestPtr::try_from(Offset::from(0))?;
     let pml4_ptr = {
         let pml4_offset_u64 = u64::try_from(SandboxMemoryLayout::PML4_OFFSET)?;
diff --git a/src/hyperlight_host/tests/integration_test.rs b/src/hyperlight_host/tests/integration_test.rs
index 2db32b4a0..1c59f773b 100644
--- a/src/hyperlight_host/tests/integration_test.rs
+++ b/src/hyperlight_host/tests/integration_test.rs
@@ -180,7 +180,7 @@ fn interrupt_same_thread() {
         }
     });
 
-    for _ in 0..NUM_ITERS {
+    for i in 0..NUM_ITERS {
         barrier.wait();
         sbox1
             .call_guest_function_by_name::<String>("Echo", "hello".to_string())
@@ -190,7 +190,7 @@ fn interrupt_same_thread() {
                 // Only allow successful calls or interrupted.
                 // The call can be successful in case the call is finished before kill() is called.
             }
-            _ => panic!("Unexpected return"),
+            Err(e) => panic!("Unexpected return from sandbox 2: {:?} iteration {}", e, i),
         };
         sbox3
             .call_guest_function_by_name::<String>("Echo", "hello".to_string())
@@ -234,7 +234,7 @@ fn interrupt_same_thread_no_barrier() {
                 // Only allow successful calls or interrupted.
                 // The call can be successful in case the call is finished before kill() is called.
             }
-            _ => panic!("Unexpected return"),
+            Err(e) => panic!("Unexpected return from sandbox 2: {:?}", e),
         };
         sbox3
             .call_guest_function_by_name::<String>("Echo", "hello".to_string())
@@ -786,3 +786,32 @@ fn log_test_messages(levelfilter: Option<LevelFilter>) {
             .unwrap();
     }
 }
+
+#[test]
+// Test to ensure that the state of a sandbox is reset after each function call.
+// This uses the simpleguest and calls the "Echo" function 1000 times with a 64-character string.
+// The fact that we can successfully call the function 1000 times and get consistent
+// results indicates that the sandbox state is being properly reset between calls.
+// If there were state leaks, we would expect to see failures or inconsistent behavior
+// as the calls accumulate; specifically, the input buffer would fill up and cause an error.
+// If the default size of the input buffer is changed, this test should be updated accordingly.
+fn sandbox_state_reset_between_calls() {
+    let mut sbox = new_uninit().unwrap().evolve(Noop::default()).unwrap();
+
+    // Create a 64-character test string
+    let test_string = "A".repeat(64);
+
+    // Call the echo function 1000 times
+    for i in 0..1000 {
+        let result = sbox
+            .call_guest_function_by_name::<String>("Echo", test_string.clone())
+            .unwrap();
+
+        // Verify that the echo function returns the same string we sent
+        assert_eq!(
+            result, test_string,
+            "Echo function returned unexpected result on iteration {}: expected '{}', got '{}'",
+            i, test_string, result
+        );
+    }
+}