diff --git a/src/hyperlight_common/src/mem.rs b/src/hyperlight_common/src/mem.rs index 4e5448ae9..4dffc6c4a 100644 --- a/src/hyperlight_common/src/mem.rs +++ b/src/hyperlight_common/src/mem.rs @@ -17,6 +17,8 @@ limitations under the License. pub const PAGE_SHIFT: u64 = 12; pub const PAGE_SIZE: u64 = 1 << 12; pub const PAGE_SIZE_USIZE: usize = 1 << 12; +// The number of pages in 1 "block". A single u64 can be used as a bitmap to keep track of all dirty pages in a block. +pub const PAGES_IN_BLOCK: usize = 64; /// A memory region in the guest address space #[derive(Debug, Clone, Copy)] diff --git a/src/hyperlight_host/benches/benchmarks.rs b/src/hyperlight_host/benches/benchmarks.rs index 96cc7ecf0..3eb3557cc 100644 --- a/src/hyperlight_host/benches/benchmarks.rs +++ b/src/hyperlight_host/benches/benchmarks.rs @@ -30,6 +30,20 @@ fn create_multiuse_sandbox() -> MultiUseSandbox { create_uninit_sandbox().evolve().unwrap() } +fn create_sandbox_with_heap_size(heap_size_mb: Option<u64>) -> MultiUseSandbox { + let path = simple_guest_as_string().unwrap(); + let config = if let Some(size_mb) = heap_size_mb { + let mut config = SandboxConfiguration::default(); + config.set_heap_size(size_mb * 1024 * 1024); // Convert MB to bytes + Some(config) + } else { + None + }; + + let uninit_sandbox = UninitializedSandbox::new(GuestBinary::FilePath(path), config).unwrap(); + uninit_sandbox.evolve().unwrap() +} + fn guest_call_benchmark(c: &mut Criterion) { let mut group = c.benchmark_group("guest_functions"); @@ -41,6 +55,29 @@ fn guest_call_benchmark(c: &mut Criterion) { b.iter(|| sbox.call::<String>("Echo", "hello\n".to_string()).unwrap()); }); + // Benchmarks a single guest restore after a guest function call. + // The benchmark only includes the time to reset the sandbox memory after the call. + group.bench_function("guest_restore", |b| { + let mut sbox = create_multiuse_sandbox(); + let snapshot = sbox.snapshot().unwrap(); + + b.iter_custom(|iters| { + let mut total_duration = std::time::Duration::ZERO; + + for _ in 0..iters { + // Dirty some pages + sbox.call::<String>("Echo", "hello\n".to_string()).unwrap(); + + // Measure only the restore operation + let start = std::time::Instant::now(); + sbox.restore(&snapshot).unwrap(); + total_duration += start.elapsed(); + } + + total_duration + }); + }); + // Benchmarks a single guest function call. // The benchmark does include the time to reset the sandbox memory after the call. group.bench_function("guest_call_with_restore", |b| { @@ -75,37 +112,57 @@ fn guest_call_benchmark(c: &mut Criterion) { group.finish(); } -fn guest_call_benchmark_large_param(c: &mut Criterion) { - let mut group = c.benchmark_group("guest_functions_with_large_parameters"); +fn guest_call_benchmark_large_params(c: &mut Criterion) { + let mut group = c.benchmark_group("2_large_parameters"); #[cfg(target_os = "windows")] group.sample_size(10); // This benchmark is very slow on Windows, so we reduce the sample size to avoid long test runs. - // This benchmark includes time to first clone a vector and string, so it is not a "pure' benchmark of the guest call, but it's still useful - group.bench_function("guest_call_with_large_parameters", |b| { - const SIZE: usize = 50 * 1024 * 1024; // 50 MB - let large_vec = vec![0u8; SIZE]; - let large_string = unsafe { String::from_utf8_unchecked(large_vec.clone()) }; // Safety: indeed above vec is valid utf8 + // Parameter sizes to test in MB. Each guest call will use two parameters of this size (vec and str). 
+ const PARAM_SIZES_MB: &[u64] = &[5, 20, 60]; - let mut config = SandboxConfiguration::default(); - config.set_input_data_size(2 * SIZE + (1024 * 1024)); // 2 * SIZE + 1 MB, to allow 1MB for the rest of the serialized function call - config.set_heap_size(SIZE as u64 * 15); + for &param_size_mb in PARAM_SIZES_MB { + let benchmark_name = format!("guest_call_restore_{}mb_params", param_size_mb); + group.bench_function(&benchmark_name, |b| { + let param_size_bytes = param_size_mb * 1024 * 1024; - let sandbox = UninitializedSandbox::new( - GuestBinary::FilePath(simple_guest_as_string().unwrap()), - Some(config), - ) - .unwrap(); - let mut sandbox = sandbox.evolve().unwrap(); + let large_vec = vec![0u8; param_size_bytes as usize]; + let large_string = String::from_utf8(large_vec.clone()).unwrap(); - b.iter(|| { - sandbox - .call_guest_function_by_name::<()>( - "LargeParameters", - (large_vec.clone(), large_string.clone()), - ) - .unwrap() + let mut config = SandboxConfiguration::default(); + config.set_heap_size(600 * 1024 * 1024); + config.set_input_data_size(300 * 1024 * 1024); + + let sandbox = UninitializedSandbox::new( + GuestBinary::FilePath(simple_guest_as_string().unwrap()), + Some(config), + ) + .unwrap(); + let mut sandbox = sandbox.evolve().unwrap(); + let snapshot = sandbox.snapshot().unwrap(); + + // Use iter_custom to avoid measuring the time it takes to clone the params + b.iter_custom(|iters| { + let mut total_duration = std::time::Duration::ZERO; + + for _ in 0..iters { + let vec_clone = large_vec.clone(); + let string_clone = large_string.clone(); + + let start = std::time::Instant::now(); + sandbox + .call_guest_function_by_name::<()>( + "LargeParameters", + (vec_clone, string_clone), + ) + .unwrap(); + sandbox.restore(&snapshot).unwrap(); + total_duration += start.elapsed(); + } + + total_duration + }); }); - }); + } group.finish(); } @@ -138,9 +195,58 @@ fn sandbox_benchmark(c: &mut Criterion) { group.finish(); } +fn sandbox_heap_size_benchmark(c: &mut Criterion) { + let mut group = c.benchmark_group("sandbox_heap_sizes"); + + const HEAP_SIZES_MB: &[u64] = &[50, 500, 995]; + + // Benchmark sandbox creation with default heap size + group.bench_function("create_sandbox_default_heap", |b| { + b.iter_with_large_drop(|| create_sandbox_with_heap_size(None)); + }); + + // Benchmark sandbox creation with different heap sizes + for &heap_size_mb in HEAP_SIZES_MB { + let benchmark_name = format!("create_sandbox_{}mb_heap", heap_size_mb); + group.bench_function(&benchmark_name, |b| { + b.iter_with_large_drop(|| create_sandbox_with_heap_size(Some(heap_size_mb))); + }); + } + + group.finish(); +} + +fn guest_call_heap_size_benchmark(c: &mut Criterion) { + let mut group = c.benchmark_group("guest_call_restore_heap_sizes"); + + const HEAP_SIZES_MB: &[Option<u64>] = + &[None, Some(50), Some(100), Some(250), Some(500), Some(995)]; + + // Benchmark guest function call with different heap sizes (including default) + for &heap_size_mb in HEAP_SIZES_MB { + let benchmark_name = match heap_size_mb { + None => "guest_call_restore_default_mb_heap".to_string(), + Some(size) => format!("guest_call_restore_{}_mb_heap", size), + }; + group.bench_function(&benchmark_name, |b| { + let mut sandbox = create_sandbox_with_heap_size(heap_size_mb); + let snapshot = sandbox.snapshot().unwrap(); + + b.iter(|| { + sandbox + .call_guest_function_by_name::<String>("Echo", "hello\n".to_string()) + .unwrap(); + sandbox.restore(&snapshot).unwrap(); + }); + }); + } + + group.finish(); +} + criterion_group! 
{ name = benches; config = Criterion::default(); - targets = guest_call_benchmark, sandbox_benchmark, guest_call_benchmark_large_param + targets = guest_call_benchmark, sandbox_benchmark, sandbox_heap_size_benchmark, guest_call_benchmark_large_params, guest_call_heap_size_benchmark } criterion_main!(benches); diff --git a/src/hyperlight_host/src/hypervisor/hyperv_linux.rs b/src/hyperlight_host/src/hypervisor/hyperv_linux.rs index d9160b6f2..3ae707553 100644 --- a/src/hyperlight_host/src/hypervisor/hyperv_linux.rs +++ b/src/hyperlight_host/src/hypervisor/hyperv_linux.rs @@ -29,6 +29,8 @@ use std::sync::Arc; use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; use log::{LevelFilter, error}; +#[cfg(mshv3)] +use mshv_bindings::MSHV_GPAP_ACCESS_OP_CLEAR; #[cfg(mshv2)] use mshv_bindings::hv_message; use mshv_bindings::{ @@ -76,6 +78,9 @@ use crate::sandbox::SandboxConfiguration; use crate::sandbox::uninitialized::SandboxRuntimeConfig; use crate::{Result, log_then_return, new_error}; +#[cfg(mshv2)] +const CLEAR_DIRTY_BIT_FLAG: u64 = 0b100; + #[cfg(gdb)] mod debug { use std::sync::{Arc, Mutex}; @@ -302,6 +307,7 @@ pub(crate) struct HypervLinuxDriver { vcpu_fd: VcpuFd, entrypoint: u64, mem_regions: Vec<MemoryRegion>, + n_initial_regions: usize, orig_rsp: GuestPtr, interrupt_handle: Arc, @@ -351,6 +357,7 @@ impl HypervLinuxDriver { vm_fd.initialize()?; vm_fd }; + vm_fd.enable_dirty_page_tracking()?; let mut vcpu_fd = vm_fd.create_vcpu(0)?; @@ -391,13 +398,31 @@ impl HypervLinuxDriver { (None, None) }; + let mut base_pfn = u64::MAX; + let mut total_size: usize = 0; + mem_regions.iter().try_for_each(|region| { - let mshv_region = region.to_owned().into(); + let mshv_region: mshv_user_mem_region = region.to_owned().into(); + if base_pfn == u64::MAX { + base_pfn = mshv_region.guest_pfn; + } + total_size += mshv_region.size as usize; vm_fd.map_user_memory(mshv_region) })?; Self::setup_initial_sregs(&mut vcpu_fd, pml4_ptr.absolute()?)?; + // Get/clear the dirty page bitmap; mshv sets all the bits dirty at initialization. + // If we don't clear them, we end up taking a complete snapshot of memory page by page, which gets + // progressively slower as the sandbox size increases. + // The downside of doing this here is that the call to get_dirty_log takes longer as the number of pages increases, + // but for larger sandboxes it's easily cheaper than copying all the pages. + + #[cfg(mshv2)] + vm_fd.get_dirty_log(base_pfn, total_size, CLEAR_DIRTY_BIT_FLAG)?; + #[cfg(mshv3)] + vm_fd.get_dirty_log(base_pfn, total_size, MSHV_GPAP_ACCESS_OP_CLEAR as u8)?; + let interrupt_handle = Arc::new(LinuxInterruptHandle { running: AtomicU64::new(0), cancel_requested: AtomicBool::new(false), @@ -428,6 +453,7 @@ impl HypervLinuxDriver { page_size: 0, vm_fd, vcpu_fd, + n_initial_regions: mem_regions.len(), mem_regions, entrypoint: entrypoint_ptr.absolute()?, orig_rsp: rsp_ptr, @@ -885,6 +911,69 @@ impl Hypervisor for HypervLinuxDriver { self.interrupt_handle.clone() } + // TODO: Implement getting additional host-mapped dirty pages. + fn get_and_clear_dirty_pages(&mut self) -> Result<Vec<u64>> { + let first_mshv_region: mshv_user_mem_region = self + .mem_regions + .first() + .ok_or(new_error!( + "tried to get dirty page bitmap of 0-sized region" + ))? 
+ .to_owned() + .into(); + + let n_contiguous = self + .mem_regions + .windows(2) + .take_while(|window| window[0].guest_region.end == window[1].guest_region.start) + .count() + + 1; // +1 because windows(2) gives us n-1 pairs for n regions + + if n_contiguous != self.n_initial_regions { + return Err(new_error!( + "get_and_clear_dirty_pages: not all regions are contiguous, expected {} but got {}", + self.n_initial_regions, + n_contiguous + )); + } + + let sandbox_total_size = self + .mem_regions + .iter() + .take(n_contiguous) + .map(|r| r.guest_region.len()) + .sum(); + + let mut sandbox_dirty_pages = self.vm_fd.get_dirty_log( + first_mshv_region.guest_pfn, + sandbox_total_size, + #[cfg(mshv2)] + CLEAR_DIRTY_BIT_FLAG, + #[cfg(mshv3)] + (MSHV_GPAP_ACCESS_OP_CLEAR as u8), + )?; + + // Sanitize bits beyond sandbox + // + // TODO: remove this once bug in mshv is fixed. The bug makes it possible + // for non-mapped memory to incorrectly be marked dirty. To fix this, we just zero out + // any bits that are not within the sandbox size. + let sandbox_pages = sandbox_total_size / self.page_size; + let last_block_idx = sandbox_dirty_pages.len().saturating_sub(1); + if let Some(last_block) = sandbox_dirty_pages.last_mut() { + let last_block_start_page = last_block_idx * 64; + let last_block_end_page = last_block_start_page + 64; + + // If the last block extends beyond the sandbox, clear the invalid bits + if last_block_end_page > sandbox_pages { + let valid_bits_in_last_block = sandbox_pages - last_block_start_page; + let mask = (1u64 << valid_bits_in_last_block) - 1; + *last_block &= mask; + } + } + Ok(sandbox_dirty_pages) + } + #[cfg(crashdump)] fn crashdump_context(&self) -> Result> { if self.rt_cfg.guest_core_dump { diff --git a/src/hyperlight_host/src/hypervisor/hyperv_windows.rs b/src/hyperlight_host/src/hypervisor/hyperv_windows.rs index a057c41cc..3ab57af4e 100644 --- a/src/hyperlight_host/src/hypervisor/hyperv_windows.rs +++ b/src/hyperlight_host/src/hypervisor/hyperv_windows.rs @@ -58,6 +58,7 @@ use super::{ use super::{HyperlightExit, Hypervisor, InterruptHandle, VirtualCPU}; use crate::hypervisor::fpu::FP_CONTROL_WORD_DEFAULT; use crate::hypervisor::wrappers::WHvGeneralRegisters; +use crate::mem::bitmap::new_page_bitmap; use crate::mem::memory_region::{MemoryRegion, MemoryRegionFlags}; use crate::mem::ptr::{GuestPtr, RawPtr}; #[cfg(crashdump)] @@ -615,13 +616,21 @@ impl Hypervisor for HypervWindowsDriver { Ok(()) } + fn get_and_clear_dirty_pages(&mut self) -> Result> { + // For now we just mark all pages dirty which is the equivalent of taking a full snapshot + let total_size = self.mem_regions.iter().map(|r| r.guest_region.len()).sum(); + new_page_bitmap(total_size, true) + } + #[instrument(err(Debug), skip_all, parent = Span::current(), level = "Trace")] unsafe fn map_region(&mut self, _rgn: &MemoryRegion) -> Result<()> { - log_then_return!("Mapping host memory into the guest not yet supported on this platform"); + // TODO: when adding support, also update `get_and_clear_dirty_pages`, see kvm/mshv for details + log_then_return!("Mapping host memory into the guest not yet supported on this platform."); } #[instrument(err(Debug), skip_all, parent = Span::current(), level = "Trace")] unsafe fn unmap_regions(&mut self, n: u64) -> Result<()> { + // TODO: when adding support, also update `get_and_clear_dirty_pages`, see kvm/mshv for details if n > 0 { log_then_return!( "Mapping host memory into the guest not yet supported on this platform" diff --git 
a/src/hyperlight_host/src/hypervisor/kvm.rs b/src/hyperlight_host/src/hypervisor/kvm.rs index 0802ecb6b..161a40f8e 100644 --- a/src/hyperlight_host/src/hypervisor/kvm.rs +++ b/src/hyperlight_host/src/hypervisor/kvm.rs @@ -21,6 +21,7 @@ use std::sync::Arc; use std::sync::Mutex; use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; +use hyperlight_common::mem::{PAGE_SIZE_USIZE, PAGES_IN_BLOCK}; use kvm_bindings::{kvm_fpu, kvm_regs, kvm_userspace_memory_region}; use kvm_ioctls::Cap::UserMemory; use kvm_ioctls::{Kvm, VcpuExit, VcpuFd, VmFd}; @@ -43,7 +44,8 @@ use super::{ use super::{HyperlightExit, Hypervisor, InterruptHandle, LinuxInterruptHandle, VirtualCPU}; #[cfg(gdb)] use crate::HyperlightError; -use crate::mem::memory_region::{MemoryRegion, MemoryRegionFlags}; +use crate::mem::bitmap::{bit_index_iterator, new_page_bitmap}; +use crate::mem::memory_region::{MemoryRegion, MemoryRegionFlags, MemoryRegionType}; use crate::mem::ptr::{GuestPtr, RawPtr}; use crate::sandbox::SandboxConfiguration; #[cfg(crashdump)] @@ -290,6 +292,7 @@ pub(crate) struct KVMDriver { entrypoint: u64, orig_rsp: GuestPtr, mem_regions: Vec, + n_initial_regions: usize, interrupt_handle: Arc, #[cfg(gdb)] @@ -372,6 +375,7 @@ impl KVMDriver { vcpu_fd, entrypoint, orig_rsp: rsp_gp, + n_initial_regions: mem_regions.len(), mem_regions, interrupt_handle: interrupt_handle.clone(), #[cfg(gdb)] @@ -750,6 +754,61 @@ impl Hypervisor for KVMDriver { self.interrupt_handle.clone() } + // TODO: Implement getting additional host-mapped dirty pages. + fn get_and_clear_dirty_pages(&mut self) -> Result> { + let n_contiguous = self + .mem_regions + .windows(2) + .take_while(|window| window[0].guest_region.end == window[1].guest_region.start) + .count() + + 1; // +1 because windows(2) gives us n-1 pairs for n regions + + if n_contiguous != self.n_initial_regions { + return Err(new_error!( + "get_and_clear_dirty_pages: not all regions are contiguous, expected {} but got {}", + self.n_initial_regions, + n_contiguous + )); + } + let mut page_indices = vec![]; + let mut current_page = 0; + + // Iterate over all memory regions and get the dirty pages for each region ignoring guard pages which cannot be dirty + for (i, mem_region) in self.mem_regions.iter().take(n_contiguous).enumerate() { + let num_pages = mem_region.guest_region.len() / PAGE_SIZE_USIZE; + let bitmap = match mem_region.flags { + MemoryRegionFlags::READ => { + // read-only page. It can never be dirty so return zero dirty pages. + new_page_bitmap(mem_region.guest_region.len(), false)? + } + _ => { + if mem_region.region_type == MemoryRegionType::GuardPage { + // Trying to get dirty pages for a guard page region results in a VMMSysError(2) + new_page_bitmap(mem_region.guest_region.len(), false)? + } else { + // Get the dirty bitmap for the memory region + self.vm_fd + .get_dirty_log(i as u32, mem_region.guest_region.len())? 
+ } + } + }; + for page_idx in bit_index_iterator(&bitmap) { + page_indices.push(current_page + page_idx); + } + current_page += num_pages; + } + + // convert vec of page indices to vec of blocks + let mut sandbox_dirty_pages = new_page_bitmap(current_page * PAGE_SIZE_USIZE, false)?; + for page_idx in page_indices { + let block_idx = page_idx / PAGES_IN_BLOCK; + let bit_idx = page_idx % PAGES_IN_BLOCK; + sandbox_dirty_pages[block_idx] |= 1 << bit_idx; + } + + Ok(sandbox_dirty_pages) + } + #[cfg(crashdump)] fn crashdump_context(&self) -> Result> { if self.rt_cfg.guest_core_dump { diff --git a/src/hyperlight_host/src/hypervisor/mod.rs b/src/hyperlight_host/src/hypervisor/mod.rs index ecf6acbc5..acf744926 100644 --- a/src/hyperlight_host/src/hypervisor/mod.rs +++ b/src/hyperlight_host/src/hypervisor/mod.rs @@ -196,6 +196,12 @@ pub(crate) trait Hypervisor: Debug + Sync + Send { None } + /// Get dirty pages as a bitmap (Vec). + /// Each bit in a u64 represents a page. + /// This also clears the bitflags, marking the pages as non-dirty. + /// TODO: Implement getting additional host-mapped dirty pages. + fn get_and_clear_dirty_pages(&mut self) -> Result>; + /// Get InterruptHandle to underlying VM fn interrupt_handle(&self) -> Arc; diff --git a/src/hyperlight_host/src/mem/bitmap.rs b/src/hyperlight_host/src/mem/bitmap.rs new file mode 100644 index 000000000..37f33fb07 --- /dev/null +++ b/src/hyperlight_host/src/mem/bitmap.rs @@ -0,0 +1,296 @@ +/* +Copyright 2025 The Hyperlight Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +use std::cmp::Ordering; + +use hyperlight_common::mem::{PAGE_SIZE_USIZE, PAGES_IN_BLOCK}; +use termcolor::{Color, ColorChoice, ColorSpec, StandardStream, WriteColor}; + +use super::layout::SandboxMemoryLayout; +use crate::{Result, log_then_return}; + +// Contains various helper functions for dealing with bitmaps. + +/// Returns a new bitmap of pages. If `init_dirty` is true, all pages are marked as dirty, otherwise all pages are clean. +/// Will return an error if given size is 0. +pub fn new_page_bitmap(size_in_bytes: usize, init_dirty: bool) -> Result> { + if size_in_bytes == 0 { + log_then_return!("Tried to create a bitmap with size 0."); + } + let num_pages = size_in_bytes.div_ceil(PAGE_SIZE_USIZE); + let num_blocks = num_pages.div_ceil(PAGES_IN_BLOCK); + match init_dirty { + false => Ok(vec![0; num_blocks]), + true => { + let mut bitmap = vec![!0u64; num_blocks]; // all pages are dirty + let num_unused_bits = num_blocks * PAGES_IN_BLOCK - num_pages; + // set the unused bits to 0, could cause problems otherwise + #[allow(clippy::unwrap_used)] + let last_block = bitmap.last_mut().unwrap(); // unwrap is safe since size_in_bytes>0 + *last_block >>= num_unused_bits; + Ok(bitmap) + } + } +} + +/// Returns the union (bitwise OR) of two bitmaps. The resulting bitmap will have the same length +/// as the longer of the two input bitmaps. 
+pub(crate) fn bitmap_union(bitmap: &[u64], other_bitmap: &[u64]) -> Vec { + let min_len = bitmap.len().min(other_bitmap.len()); + let max_len = bitmap.len().max(other_bitmap.len()); + + let mut result = vec![0; max_len]; + + for i in 0..min_len { + result[i] = bitmap[i] | other_bitmap[i]; + } + + match bitmap.len().cmp(&other_bitmap.len()) { + Ordering::Greater => { + result[min_len..].copy_from_slice(&bitmap[min_len..]); + } + Ordering::Less => { + result[min_len..].copy_from_slice(&other_bitmap[min_len..]); + } + Ordering::Equal => {} + } + + result +} + +// Used as a helper struct to implement an iterator on. +struct SetBitIndices<'a> { + bitmap: &'a [u64], + block_index: usize, // one block is 1 u64, which is 64 pages + current: u64, // the current block we are iterating over, or 0 if first iteration +} + +/// Iterates over the zero-based indices of the set bits in the given bitmap. +pub(crate) fn bit_index_iterator(bitmap: &[u64]) -> impl Iterator + '_ { + SetBitIndices { + bitmap, + block_index: 0, + current: 0, + } +} + +impl Iterator for SetBitIndices<'_> { + type Item = usize; + + fn next(&mut self) -> Option { + while self.current == 0 { + // will always enter this on first iteration because current is initialized to 0 + if self.block_index >= self.bitmap.len() { + // no more blocks to iterate over + return None; + } + self.current = self.bitmap[self.block_index]; + self.block_index += 1; + } + let trailing_zeros = self.current.trailing_zeros(); + self.current &= self.current - 1; // Clear the least significant set bit + Some((self.block_index - 1) * 64 + trailing_zeros as usize) // block_index guaranteed to be > 0 at this point + } +} + +// Unused but useful for debugging +// Prints the dirty bitmap in a human-readable format, coloring each page according to its region +// NOTE: Might need to be updated if the memory layout changes +#[allow(dead_code)] +pub(crate) fn print_dirty_bitmap(bitmap: &[u64], layout: &SandboxMemoryLayout) { + let mut stdout = StandardStream::stdout(ColorChoice::Auto); + + // Helper function to determine which memory region a page belongs to + fn get_region_info(page_index: usize, layout: &SandboxMemoryLayout) -> (&'static str, Color) { + let page_offset = page_index * PAGE_SIZE_USIZE; + + // Check each memory region in order, using available methods and approximations + if page_offset >= layout.init_data_offset { + ("INIT_DATA", Color::Ansi256(129)) // Purple + } else if page_offset >= layout.get_top_of_user_stack_offset() { + ("STACK", Color::Ansi256(208)) // Orange + } else if page_offset >= layout.get_guard_page_offset() { + ("GUARD_PAGE", Color::White) + } else if page_offset >= layout.guest_heap_buffer_offset { + ("HEAP", Color::Red) + } else if page_offset >= layout.output_data_buffer_offset { + ("OUTPUT_DATA", Color::Green) + } else if page_offset >= layout.input_data_buffer_offset { + ("INPUT_DATA", Color::Blue) + } else if page_offset >= layout.host_function_definitions_buffer_offset { + ("HOST_FUNC_DEF", Color::Cyan) + } else if page_offset >= layout.peb_address { + ("PEB", Color::Magenta) + } else if page_offset >= layout.get_guest_code_offset() { + ("CODE", Color::Yellow) + } else { + // Everything up to and including guest code should be PAGE_TABLES + ("PAGE_TABLES", Color::Ansi256(14)) // Bright cyan + } + } + + let mut num_dirty_pages = 0; + for &block in bitmap.iter() { + num_dirty_pages += block.count_ones() as usize; + } + + for (i, &block) in bitmap.iter().enumerate() { + if block != 0 { + print!("Block {:3}: ", i); + + // Print each 
bit in the block with appropriate color + for bit_pos in 0..64 { + let bit_mask = 1u64 << bit_pos; + let page_index = i * 64 + bit_pos; + let (_region_name, color) = get_region_info(page_index, layout); + + let mut color_spec = ColorSpec::new(); + color_spec.set_fg(Some(color)); + + if block & bit_mask != 0 { + // Make 1s bold with dark background to stand out from 0s + color_spec.set_bold(true).set_bg(Some(Color::Black)); + let _ = stdout.set_color(&color_spec); + print!("1"); + } else { + // 0s are colored but not bold, no background + let _ = stdout.set_color(&color_spec); + print!("0"); + } + let _ = stdout.reset(); + } + + // Print a legend for this block showing which regions are represented + let mut regions_in_block = std::collections::HashMap::new(); + for bit_pos in 0..64 { + let bit_mask = 1u64 << bit_pos; + if block & bit_mask != 0 { + let page_index = i * 64 + bit_pos; + let (region_name, color) = get_region_info(page_index, layout); + regions_in_block.insert(region_name, color); + } + } + + if !regions_in_block.is_empty() { + print!(" ["); + let mut sorted_regions: Vec<_> = regions_in_block.iter().collect(); + sorted_regions.sort_by_key(|(name, _)| *name); + for (i, (region_name, color)) in sorted_regions.iter().enumerate() { + if i > 0 { + print!(", "); + } + let mut color_spec = ColorSpec::new(); + color_spec.set_fg(Some(**color)).set_bold(true); + let _ = stdout.set_color(&color_spec); + print!("{}", region_name); + let _ = stdout.reset(); + } + print!("]"); + } + println!(); + } + } + // Print the total number of dirty pages + println!("Total dirty pages: {}", num_dirty_pages); +} + +#[cfg(test)] +mod tests { + use hyperlight_common::mem::PAGE_SIZE_USIZE; + + use crate::Result; + use crate::mem::bitmap::{bit_index_iterator, bitmap_union, new_page_bitmap}; + + #[test] + fn new_page_bitmap_test() -> Result<()> { + let bitmap = new_page_bitmap(1, false)?; + assert_eq!(bitmap.len(), 1); + assert_eq!(bitmap[0], 0); + + let bitmap = new_page_bitmap(1, true)?; + assert_eq!(bitmap.len(), 1); + assert_eq!(bitmap[0], 1); + + let bitmap = new_page_bitmap(32 * PAGE_SIZE_USIZE, false)?; + assert_eq!(bitmap.len(), 1); + assert_eq!(bitmap[0], 0); + + let bitmap = new_page_bitmap(32 * PAGE_SIZE_USIZE, true)?; + assert_eq!(bitmap.len(), 1); + assert_eq!(bitmap[0], 0x0000_0000_FFFF_FFFF); + Ok(()) + } + + #[test] + fn page_iterator() { + let data = vec![0b1000010100, 0b01, 0b100000000000000011]; + let mut iter = bit_index_iterator(&data); + assert_eq!(iter.next(), Some(2)); + assert_eq!(iter.next(), Some(4)); + assert_eq!(iter.next(), Some(9)); + assert_eq!(iter.next(), Some(64)); + assert_eq!(iter.next(), Some(128)); + assert_eq!(iter.next(), Some(129)); + assert_eq!(iter.next(), Some(145)); + assert_eq!(iter.next(), None); + + let data_2 = vec![0, 0, 0]; + let mut iter_2 = bit_index_iterator(&data_2); + assert_eq!(iter_2.next(), None); + + let data_3 = vec![0, 0, 0b1, 1 << 63]; + let mut iter_3 = bit_index_iterator(&data_3); + assert_eq!(iter_3.next(), Some(128)); + assert_eq!(iter_3.next(), Some(255)); + assert_eq!(iter_3.next(), None); + + let data_4 = vec![]; + let mut iter_4 = bit_index_iterator(&data_4); + assert_eq!(iter_4.next(), None); + } + + #[test] + fn union() -> Result<()> { + let a = 0b1000010100; + let b = 0b01; + let c = 0b100000000000000011; + let d = 0b101010100000011000000011; + let e = 0b000000000000001000000000000000000000; + let f = 0b100000000000000001010000000001010100000000000; + let bitmap = vec![a, b, c]; + let other_bitmap = vec![d, e, f]; + let union = 
bitmap_union(&bitmap, &other_bitmap); + assert_eq!(union, vec![a | d, b | e, c | f]); + + // different length + let union = bitmap_union(&[a], &[d, e, f]); + assert_eq!(union, vec![a | d, e, f]); + + let union = bitmap_union(&[a, b, c], &[d]); + assert_eq!(union, vec![a | d, b, c]); + + let union = bitmap_union(&[], &[d, e]); + assert_eq!(union, vec![d, e]); + + let union = bitmap_union(&[a, b, c], &[]); + assert_eq!(union, vec![a, b, c]); + + let union = bitmap_union(&[], &[]); + let empty: Vec = vec![]; + assert_eq!(union, empty); + + Ok(()) + } +} diff --git a/src/hyperlight_host/src/mem/elf.rs b/src/hyperlight_host/src/mem/elf.rs index 3efe09b4f..84bbc2b7b 100644 --- a/src/hyperlight_host/src/mem/elf.rs +++ b/src/hyperlight_host/src/mem/elf.rs @@ -24,6 +24,7 @@ use goblin::elf32::program_header::PT_LOAD; #[cfg(feature = "init-paging")] use goblin::elf64::program_header::PT_LOAD; +use super::shared_mem::ExclusiveSharedMemory; use crate::{Result, log_then_return, new_error}; pub(crate) struct ElfInfo { @@ -73,15 +74,26 @@ impl ElfInfo { .unwrap(); (max_phdr.p_vaddr + max_phdr.p_memsz - self.get_base_va()) as usize } - pub(crate) fn load_at(&self, load_addr: usize, target: &mut [u8]) -> Result<()> { + pub(crate) fn load_at( + &self, + load_addr: usize, + guest_code_offset: usize, + excl: &mut ExclusiveSharedMemory, + ) -> Result<()> { let base_va = self.get_base_va(); for phdr in self.phdrs.iter().filter(|phdr| phdr.p_type == PT_LOAD) { let start_va = (phdr.p_vaddr - base_va) as usize; let payload_offset = phdr.p_offset as usize; let payload_len = phdr.p_filesz as usize; - target[start_va..start_va + payload_len] - .copy_from_slice(&self.payload[payload_offset..payload_offset + payload_len]); - target[start_va + payload_len..start_va + phdr.p_memsz as usize].fill(0); + excl.copy_from_slice( + &self.payload[payload_offset..payload_offset + payload_len], + guest_code_offset + start_va, + )?; + + excl.zero_fill( + guest_code_offset + start_va + payload_len, + phdr.p_memsz as usize - payload_len, + )?; } let get_addend = |name, r: &Reloc| { r.r_addend @@ -104,8 +116,10 @@ impl ElfInfo { match r.r_type { R_X86_64_RELATIVE => { let addend = get_addend("R_X86_64_RELATIVE", r)?; - target[r.r_offset as usize..r.r_offset as usize + 8] - .copy_from_slice(&(load_addr as i64 + addend).to_le_bytes()); + excl.copy_from_slice( + &(load_addr as i64 + addend).to_le_bytes(), + guest_code_offset + r.r_offset as usize, + )?; } R_X86_64_NONE => {} _ => { diff --git a/src/hyperlight_host/src/mem/exe.rs b/src/hyperlight_host/src/mem/exe.rs index bf1724317..de2b9d28d 100644 --- a/src/hyperlight_host/src/mem/exe.rs +++ b/src/hyperlight_host/src/mem/exe.rs @@ -20,6 +20,7 @@ use std::vec::Vec; use super::elf::ElfInfo; use super::ptr_offset::Offset; +use super::shared_mem::ExclusiveSharedMemory; use crate::Result; // This is used extremely infrequently, so being unusually large for PE @@ -71,10 +72,15 @@ impl ExeInfo { // copying into target, but the PE loader chooses to apply // relocations in its owned representation of the PE contents, // which requires it to be &mut. 
- pub fn load(&mut self, load_addr: usize, target: &mut [u8]) -> Result<()> { + pub fn load( + &mut self, + load_addr: usize, + guest_code_offset: usize, + target: &mut ExclusiveSharedMemory, + ) -> Result<()> { match self { ExeInfo::Elf(elf) => { - elf.load_at(load_addr, target)?; + elf.load_at(load_addr, guest_code_offset, target)?; } } Ok(()) diff --git a/src/hyperlight_host/src/mem/layout.rs b/src/hyperlight_host/src/mem/layout.rs index 04edc9bcc..df8173747 100644 --- a/src/hyperlight_host/src/mem/layout.rs +++ b/src/hyperlight_host/src/mem/layout.rs @@ -111,10 +111,10 @@ pub(crate) struct SandboxMemoryLayout { pub(crate) host_function_definitions_buffer_offset: usize, pub(super) input_data_buffer_offset: usize, pub(super) output_data_buffer_offset: usize, - guest_heap_buffer_offset: usize, + pub(super) guest_heap_buffer_offset: usize, guard_page_offset: usize, guest_user_stack_buffer_offset: usize, // the lowest address of the user stack - init_data_offset: usize, + pub(super) init_data_offset: usize, // other pub(crate) peb_address: usize, diff --git a/src/hyperlight_host/src/mem/memory_region.rs b/src/hyperlight_host/src/mem/memory_region.rs index b46426c3b..6b2c04f42 100644 --- a/src/hyperlight_host/src/mem/memory_region.rs +++ b/src/hyperlight_host/src/mem/memory_region.rs @@ -31,7 +31,7 @@ use bitflags::bitflags; use hyperlight_common::mem::PAGE_SHIFT; use hyperlight_common::mem::PAGE_SIZE_USIZE; #[cfg(kvm)] -use kvm_bindings::{KVM_MEM_READONLY, kvm_userspace_memory_region}; +use kvm_bindings::{KVM_MEM_LOG_DIRTY_PAGES, KVM_MEM_READONLY, kvm_userspace_memory_region}; #[cfg(mshv2)] use mshv_bindings::{ HV_MAP_GPA_EXECUTABLE, HV_MAP_GPA_PERMISSIONS_NONE, HV_MAP_GPA_READABLE, HV_MAP_GPA_WRITABLE, @@ -326,7 +326,7 @@ impl From for kvm_bindings::kvm_userspace_memory_region { userspace_addr: region.host_region.start as u64, flags: match perm_flags { MemoryRegionFlags::READ => KVM_MEM_READONLY, - _ => 0, // normal, RWX + _ => KVM_MEM_LOG_DIRTY_PAGES, // normal, RWX }, } } diff --git a/src/hyperlight_host/src/mem/mgr.rs b/src/hyperlight_host/src/mem/mgr.rs index f6ce32c26..4ac19ad69 100644 --- a/src/hyperlight_host/src/mem/mgr.rs +++ b/src/hyperlight_host/src/mem/mgr.rs @@ -15,6 +15,7 @@ limitations under the License. */ use std::cmp::Ordering; +use std::sync::Arc; use hyperlight_common::flatbuffer_wrappers::function_call::{ FunctionCall, validate_guest_function_call_buffer, @@ -73,6 +74,8 @@ pub(crate) struct SandboxMemoryManager { pub(crate) entrypoint_offset: Offset, /// How many memory regions were mapped after sandbox creation pub(crate) mapped_rgns: u64, + /// Most recent snapshot taken, in other words, the most recent state that `self` has been in (disregarding currently dirty pages) + pub(crate) most_recent_snapshot: Option>, } impl SandboxMemoryManager @@ -93,6 +96,7 @@ where load_addr, entrypoint_offset, mapped_rgns: 0, + most_recent_snapshot: None, } } @@ -259,8 +263,18 @@ where } } - pub(crate) fn snapshot(&mut self) -> Result { - SharedMemorySnapshot::new(&mut self.shared_mem, self.mapped_rgns) + pub(crate) fn snapshot( + &mut self, + dirty_pages_bitmap: &[u64], + ) -> Result> { + let snapshot = Arc::new(SharedMemorySnapshot::new( + &mut self.shared_mem, + dirty_pages_bitmap, + self.mapped_rgns, + self.most_recent_snapshot.clone(), + )?); + self.most_recent_snapshot = Some(snapshot.clone()); + Ok(snapshot) } /// This function restores a memory snapshot from a given snapshot. 
@@ -268,16 +282,21 @@ where /// Returns the number of memory regions mapped into the sandbox /// that need to be unmapped in order for the restore to be /// completed. - pub(crate) fn restore_snapshot(&mut self, snapshot: &SharedMemorySnapshot) -> Result { - if self.shared_mem.mem_size() != snapshot.mem_size() { - return Err(new_error!( - "Snapshot size does not match current memory size: {} != {}", - self.shared_mem.raw_mem_size(), - snapshot.mem_size() - )); - } + pub(crate) fn restore_snapshot( + &mut self, + snapshot: &Arc, + dirty_pages_bitmap: &[u64], + ) -> Result { let old_rgns = self.mapped_rgns; - self.mapped_rgns = snapshot.restore_from_snapshot(&mut self.shared_mem)?; + self.mapped_rgns = snapshot.restore_from_snapshot( + &mut self.shared_mem, + dirty_pages_bitmap, + &self.most_recent_snapshot, + )?; + + // Update the most recent snapshot to the one we just restored to + self.most_recent_snapshot = Some(snapshot.clone()); + Ok(old_rgns - self.mapped_rgns) } @@ -341,7 +360,8 @@ impl SandboxMemoryManager { exe_info.load( load_addr.clone().try_into()?, - &mut shared_mem.as_mut_slice()[layout.get_guest_code_offset()..], + layout.get_guest_code_offset(), + &mut shared_mem, )?; Ok(Self::new(layout, shared_mem, load_addr, entrypoint_offset)) @@ -406,6 +426,7 @@ impl SandboxMemoryManager { load_addr: self.load_addr.clone(), entrypoint_offset: self.entrypoint_offset, mapped_rgns: 0, + most_recent_snapshot: self.most_recent_snapshot.clone(), }, SandboxMemoryManager { shared_mem: gshm, @@ -413,6 +434,7 @@ impl SandboxMemoryManager { load_addr: self.load_addr.clone(), entrypoint_offset: self.entrypoint_offset, mapped_rgns: 0, + most_recent_snapshot: self.most_recent_snapshot.clone(), }, ) } diff --git a/src/hyperlight_host/src/mem/mod.rs b/src/hyperlight_host/src/mem/mod.rs index 1bcc03eae..4ae43614f 100644 --- a/src/hyperlight_host/src/mem/mod.rs +++ b/src/hyperlight_host/src/mem/mod.rs @@ -14,6 +14,8 @@ See the License for the specific language governing permissions and limitations under the License. */ +/// Various helper functions for working with bitmaps +pub(crate) mod bitmap; /// A simple ELF loader pub(crate) mod elf; /// A generic wrapper for executable files (PE, ELF, etc) diff --git a/src/hyperlight_host/src/mem/shared_mem.rs b/src/hyperlight_host/src/mem/shared_mem.rs index 50c809f44..eec000f85 100644 --- a/src/hyperlight_host/src/mem/shared_mem.rs +++ b/src/hyperlight_host/src/mem/shared_mem.rs @@ -19,9 +19,9 @@ use std::ffi::c_void; use std::io::Error; #[cfg(target_os = "linux")] use std::ptr::null_mut; -use std::sync::{Arc, RwLock}; +use std::sync::{Arc, Mutex, RwLock}; -use hyperlight_common::mem::PAGE_SIZE_USIZE; +use hyperlight_common::mem::{PAGE_SIZE_USIZE, PAGES_IN_BLOCK}; use tracing::{Span, instrument}; #[cfg(target_os = "windows")] use windows::Win32::Foundation::{CloseHandle, HANDLE, INVALID_HANDLE_VALUE}; @@ -78,8 +78,10 @@ macro_rules! 
generate_writer { #[allow(dead_code)] pub(crate) fn $fname(&mut self, offset: usize, value: $ty) -> Result<()> { let data = self.as_mut_slice(); - bounds_check!(offset, std::mem::size_of::<$ty>(), data.len()); - data[offset..offset + std::mem::size_of::<$ty>()].copy_from_slice(&value.to_le_bytes()); + let size = std::mem::size_of::<$ty>(); + bounds_check!(offset, size, data.len()); + data[offset..offset + size].copy_from_slice(&value.to_le_bytes()); + self.mark_pages_dirty(offset, size)?; Ok(()) } }; @@ -133,6 +135,7 @@ impl Drop for HostMapping { #[derive(Debug)] pub struct ExclusiveSharedMemory { region: Arc, + dirty_page_tracker: Arc>>, } unsafe impl Send for ExclusiveSharedMemory {} @@ -147,6 +150,8 @@ unsafe impl Send for ExclusiveSharedMemory {} #[derive(Debug)] pub struct GuestSharedMemory { region: Arc, + dirty_page_tracker: Arc>>, + /// The lock that indicates this shared memory is being used by non-Rust code /// /// This lock _must_ be held whenever the guest is executing, @@ -298,6 +303,8 @@ unsafe impl Send for GuestSharedMemory {} #[derive(Clone, Debug)] pub struct HostSharedMemory { region: Arc, + dirty_page_tracker: Arc>>, + lock: Arc>, } unsafe impl Send for HostSharedMemory {} @@ -316,6 +323,7 @@ impl ExclusiveSharedMemory { }; use crate::error::HyperlightError::{MemoryRequestTooBig, MmapFailed, MprotectFailed}; + use crate::mem::bitmap::new_page_bitmap; if min_size_bytes == 0 { return Err(new_error!("Cannot create shared memory with size 0")); @@ -370,23 +378,39 @@ impl ExclusiveSharedMemory { return Err(MprotectFailed(Error::last_os_error().raw_os_error())); } + // HostMapping is only non-Send/Sync because raw pointers + // are not ("as a lint", as the Rust docs say). We don't + // want to mark HostMapping Send/Sync immediately, because + // that could socially imply that it's "safe" to use + // unsafe accesses from multiple threads at once. Instead, we + // directly impl Send and Sync on this type. Since this + // type does have Send and Sync manually impl'd, the Arc + // is not pointless as the lint suggests. + #[allow(clippy::arc_with_non_send_sync)] + let host_mapping = Arc::new(HostMapping { + ptr: addr as *mut u8, + size: total_size, + }); + + let dirty_page_tracker = new_page_bitmap(min_size_bytes, false)?; + Ok(Self { - // HostMapping is only non-Send/Sync because raw pointers - // are not ("as a lint", as the Rust docs say). We don't - // want to mark HostMapping Send/Sync immediately, because - // that could socially imply that it's "safe" to use - // unsafe accesses from multiple threads at once. Instead, we - // directly impl Send and Sync on this type. Since this - // type does have Send and Sync manually impl'd, the Arc - // is not pointless as the lint suggests. - #[allow(clippy::arc_with_non_send_sync)] - region: Arc::new(HostMapping { - ptr: addr as *mut u8, - size: total_size, - }), + region: host_mapping, + dirty_page_tracker: Arc::new(Mutex::new(dirty_page_tracker)), }) } + /// Gets the dirty bitmap and then clears it in self. + pub(crate) fn get_and_clear_dirty_pages(&mut self) -> Result> { + let mut guard = self + .dirty_page_tracker + .try_lock() + .map_err(|_| new_error!("Failed to acquire lock on dirty page tracker"))?; + let bitmap = guard.clone(); + guard.fill(0); + Ok(bitmap) + } + /// Create a new region of shared memory with the given minimum /// size in bytes. The region will be surrounded by guard pages. 
/// @@ -394,6 +418,8 @@ impl ExclusiveSharedMemory { #[cfg(target_os = "windows")] #[instrument(skip_all, parent = Span::current(), level= "Trace")] pub fn new(min_size_bytes: usize) -> Result { + use super::bitmap::new_page_bitmap; + if min_size_bytes == 0 { return Err(new_error!("Cannot create shared memory with size 0")); } @@ -484,21 +510,26 @@ impl ExclusiveSharedMemory { log_then_return!(WindowsAPIError(e.clone())); } + // HostMapping is only non-Send/Sync because raw pointers + // are not ("as a lint", as the Rust docs say). We don't + // want to mark HostMapping Send/Sync immediately, because + // that could socially imply that it's "safe" to use + // unsafe accesses from multiple threads at once. Instead, we + // directly impl Send and Sync on this type. Since this + // type does have Send and Sync manually impl'd, the Arc + // is not pointless as the lint suggests. + #[allow(clippy::arc_with_non_send_sync)] + let host_mapping = Arc::new(HostMapping { + ptr: addr.Value as *mut u8, + size: total_size, + handle, + }); + + let dirty_page_tracker = new_page_bitmap(min_size_bytes, false)?; + Ok(Self { - // HostMapping is only non-Send/Sync because raw pointers - // are not ("as a lint", as the Rust docs say). We don't - // want to mark HostMapping Send/Sync immediately, because - // that could socially imply that it's "safe" to use - // unsafe accesses from multiple threads at once. Instead, we - // directly impl Send and Sync on this type. Since this - // type does have Send and Sync manually impl'd, the Arc - // is not pointless as the lint suggests. - #[allow(clippy::arc_with_non_send_sync)] - region: Arc::new(HostMapping { - ptr: addr.Value as *mut u8, - size: total_size, - handle, - }), + region: host_mapping, + dirty_page_tracker: Arc::new(Mutex::new(dirty_page_tracker)), }) } @@ -576,7 +607,10 @@ impl ExclusiveSharedMemory { /// the safety documentation of pointer::offset. /// /// This is ensured by a check in ::new() - pub(super) fn as_mut_slice(&mut self) -> &mut [u8] { + /// + /// Additionally, writes to the returned slice will not mark pages as dirty. + /// User must call `mark_pages_dirty` manually to mark pages as dirty. + fn as_mut_slice(&mut self) -> &mut [u8] { unsafe { std::slice::from_raw_parts_mut(self.base_ptr(), self.mem_size()) } } @@ -610,6 +644,16 @@ impl ExclusiveSharedMemory { let data = self.as_mut_slice(); bounds_check!(offset, src.len(), data.len()); data[offset..offset + src.len()].copy_from_slice(src); + self.mark_pages_dirty(offset, src.len())?; + Ok(()) + } + + /// Copies bytes from `self` to `dst` starting at offset + #[instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace")] + pub fn copy_to_slice(&self, dst: &mut [u8], offset: usize) -> Result<()> { + let data = self.as_slice(); + bounds_check!(offset, dst.len(), data.len()); + dst.copy_from_slice(&data[offset..offset + dst.len()]); Ok(()) } @@ -621,6 +665,40 @@ impl ExclusiveSharedMemory { Ok(self.base_addr() + offset) } + /// Fill the memory in the range `[offset, offset + len)` with `value` + #[instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace")] + pub fn zero_fill(&mut self, offset: usize, len: usize) -> Result<()> { + bounds_check!(offset, len, self.mem_size()); + let data = self.as_mut_slice(); + data[offset..offset + len].fill(0); + self.mark_pages_dirty(offset, len)?; + Ok(()) + } + + /// Same as `copy_from_slice` but doesn't dirty the pages. + /// # Safety + /// This function is unsafe because it does not mark the pages as dirty. 
+ /// Only use this if you are certain that the pages do not need to be marked as dirty. + #[instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace")] + pub unsafe fn restore_copy_from_slice(&mut self, src: &[u8], offset: usize) -> Result<()> { + let data = self.as_mut_slice(); + bounds_check!(offset, src.len(), data.len()); + data[offset..offset + src.len()].copy_from_slice(src); + Ok(()) + } + + /// Same as `zero_fill` but doesn't dirty the pages. + /// # Safety + /// This function is unsafe because it does not mark the pages as dirty. + /// Only use this if you are certain that the pages do not need to be marked as dirty. + #[instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace")] + pub(crate) unsafe fn restore_zero_fill(&mut self, offset: usize, len: usize) -> Result<()> { + bounds_check!(offset, len, self.mem_size()); + let data = self.as_mut_slice(); + data[offset..offset + len].fill(0); + Ok(()) + } + generate_reader!(read_u8, u8); generate_reader!(read_i8, i8); generate_reader!(read_u16, u16); @@ -654,15 +732,35 @@ impl ExclusiveSharedMemory { ( HostSharedMemory { region: self.region.clone(), + dirty_page_tracker: self.dirty_page_tracker.clone(), lock: lock.clone(), }, GuestSharedMemory { region: self.region.clone(), + dirty_page_tracker: self.dirty_page_tracker.clone(), lock: lock.clone(), }, ) } + /// Marks pages that cover bytes [offset, offset + size) as dirty + pub(super) fn mark_pages_dirty(&mut self, offset: usize, size: usize) -> Result<()> { + bounds_check!(offset, size, self.mem_size()); + let mut bitmap = self + .dirty_page_tracker + .try_lock() + .map_err(|_| new_error!("Failed to lock dirty page tracker"))?; + + let start_page = offset / PAGE_SIZE_USIZE; + let end_page = (offset + size - 1) / PAGE_SIZE_USIZE; // offset + size - 1 is the last affected byte. 
+ for page_idx in start_page..=end_page { + let block_idx = page_idx / PAGES_IN_BLOCK; + let bit_idx = page_idx % PAGES_IN_BLOCK; + bitmap[block_idx] |= 1 << bit_idx; + } + Ok(()) + } + /// Gets the file handle of the shared memory region for this Sandbox #[cfg(target_os = "windows")] pub fn get_mmap_file_handle(&self) -> HANDLE { @@ -740,6 +838,7 @@ impl SharedMemory for GuestSharedMemory { fn region(&self) -> &HostMapping { &self.region } + fn with_exclusivity T>( &mut self, f: F, @@ -750,6 +849,7 @@ impl SharedMemory for GuestSharedMemory { .map_err(|e| new_error!("Error locking at {}:{}: {}", file!(), line!(), e))?; let mut excl = ExclusiveSharedMemory { region: self.region.clone(), + dirty_page_tracker: self.dirty_page_tracker.clone(), }; let ret = f(&mut excl); drop(excl); @@ -800,7 +900,7 @@ impl HostSharedMemory { /// Write a value of type T, whose representation is the same /// between the sandbox and the host, and which has no invalid bit /// patterns - pub fn write(&self, offset: usize, data: T) -> Result<()> { + pub fn write(&mut self, offset: usize, data: T) -> Result<()> { bounds_check!(offset, std::mem::size_of::(), self.mem_size()); unsafe { let slice: &[u8] = core::slice::from_raw_parts( @@ -809,6 +909,7 @@ impl HostSharedMemory { ); self.copy_from_slice(slice, offset)?; } + self.mark_pages_dirty(offset, std::mem::size_of::())?; Ok(()) } @@ -831,9 +932,8 @@ impl HostSharedMemory { Ok(()) } - /// Copy the contents of the sandbox at the specified offset into - /// the slice - pub fn copy_from_slice(&self, slice: &[u8], offset: usize) -> Result<()> { + /// Copy the contents of the given slice into self + pub fn copy_from_slice(&mut self, slice: &[u8], offset: usize) -> Result<()> { bounds_check!(offset, slice.len(), self.mem_size()); let base = self.base_ptr().wrapping_add(offset); let guard = self @@ -847,6 +947,7 @@ impl HostSharedMemory { } } drop(guard); + self.mark_pages_dirty(offset, slice.len())?; Ok(()) } @@ -864,6 +965,7 @@ impl HostSharedMemory { unsafe { base.wrapping_add(i).write_volatile(value) }; } drop(guard); + self.mark_pages_dirty(offset, len)?; Ok(()) } @@ -976,12 +1078,31 @@ impl HostSharedMemory { Ok(to_return) } + + /// Marks pages that cover bytes [offset, offset + size) as dirty + pub(super) fn mark_pages_dirty(&mut self, offset: usize, size: usize) -> Result<()> { + bounds_check!(offset, size, self.mem_size()); + let mut bitmap = self + .dirty_page_tracker + .try_lock() + .map_err(|_| new_error!("Failed to lock dirty page tracker"))?; + + let start_page = offset / PAGE_SIZE_USIZE; + let end_page = (offset + size - 1) / PAGE_SIZE_USIZE; // offset + size - 1 is the last affected byte. 
+ for page_idx in start_page..=end_page { + let block_idx = page_idx / PAGES_IN_BLOCK; + let bit_idx = page_idx % PAGES_IN_BLOCK; + bitmap[block_idx] |= 1 << bit_idx; + } + Ok(()) + } } impl SharedMemory for HostSharedMemory { fn region(&self) -> &HostMapping { &self.region } + fn with_exclusivity T>( &mut self, f: F, @@ -992,6 +1113,7 @@ impl SharedMemory for HostSharedMemory { .map_err(|e| new_error!("Error locking at {}:{}: {}", file!(), line!(), e))?; let mut excl = ExclusiveSharedMemory { region: self.region.clone(), + dirty_page_tracker: self.dirty_page_tracker.clone(), }; let ret = f(&mut excl); drop(excl); @@ -1045,7 +1167,7 @@ mod tests { let mem_size: usize = 4096; let vec_len = 10; let eshm = ExclusiveSharedMemory::new(mem_size)?; - let (hshm, _) = eshm.build(); + let (mut hshm, _) = eshm.build(); let vec = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]; // write the value to the memory at the beginning. hshm.copy_from_slice(&vec, 0)?; @@ -1132,8 +1254,8 @@ mod tests { #[test] fn clone() { let eshm = ExclusiveSharedMemory::new(PAGE_SIZE_USIZE).unwrap(); - let (hshm1, _) = eshm.build(); - let hshm2 = hshm1.clone(); + let (mut hshm1, _) = eshm.build(); + let mut hshm2 = hshm1.clone(); // after hshm1 is cloned, hshm1 and hshm2 should have identical // memory sizes and pointers. diff --git a/src/hyperlight_host/src/mem/shared_mem_snapshot.rs b/src/hyperlight_host/src/mem/shared_mem_snapshot.rs index b7f461716..f49addac3 100644 --- a/src/hyperlight_host/src/mem/shared_mem_snapshot.rs +++ b/src/hyperlight_host/src/mem/shared_mem_snapshot.rs @@ -14,8 +14,13 @@ See the License for the specific language governing permissions and limitations under the License. */ +use std::collections::{HashMap, HashSet}; +use std::sync::Arc; + +use hyperlight_common::mem::PAGE_SIZE_USIZE; use tracing::{Span, instrument}; +use super::bitmap::bit_index_iterator; use super::shared_mem::SharedMemory; use crate::Result; @@ -23,85 +28,630 @@ use crate::Result; /// of the memory therein #[derive(Clone)] pub(crate) struct SharedMemorySnapshot { - snapshot: Vec, + /// Data (pages) in this snapshot + data: HashMap>, // page_number -> page_data. Each entry is 1 page /// How many non-main-RAM regions were mapped when this snapshot was taken? mapped_rgns: u64, + /// Parent snapshot (or None if root) + parent: Option>, + /// Size of the sandbox this snapshot was taken of + sandbox_size: usize, } impl SharedMemorySnapshot { - /// Take a snapshot of the memory in `shared_mem`, then create a new - /// instance of `Self` with the snapshot stored therein. + /// Take a snapshot of memory in `shared_mem` assuming `dirty_pages_bitmap` are the only + /// changed pages since `parent` snapshot was taken. 
#[instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace")] - pub(super) fn new(shared_mem: &mut S, mapped_rgns: u64) -> Result { - // TODO: Track dirty pages instead of copying entire memory - let snapshot = shared_mem.with_exclusivity(|e| e.copy_all_to_vec())??; + pub(super) fn new( + shared_mem: &mut S, + dirty_pages_bitmap: &[u64], + mapped_rgns: u64, + parent: Option>, + ) -> Result { + let data = shared_mem.with_exclusivity(|e| -> Result>> { + let mut snapshot = HashMap::new(); + bit_index_iterator(dirty_pages_bitmap).try_for_each(|idx| { + let mut page = vec![0u8; PAGE_SIZE_USIZE]; + e.copy_to_slice(&mut page, idx * PAGE_SIZE_USIZE)?; + snapshot.insert(idx, page); + crate::Result::Ok(()) + })?; + Ok(snapshot) + })??; Ok(Self { - snapshot, + data, mapped_rgns, + parent, + sandbox_size: shared_mem.mem_size(), }) } - /// Take another snapshot of the internally-stored `SharedMemory`, - /// then store it internally. + /// Restore shared memory to the state it was in when `self` was taken, + /// assuming `most_recent_snapshot` is the most recent snapshot `shared_mem` was + /// in and `current_dirty_pages` are the pages that have been modified since. #[instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace")] + pub(super) fn restore_from_snapshot( + self: &Arc, + shared_mem: &mut S, + current_dirty_pages: &[u64], + most_recent_snapshot: &Option>, + ) -> Result { + if self.sandbox_size() != shared_mem.mem_size() { + return Err(crate::new_error!( + "Snapshot size does not match current memory size: {} != {}", + self.sandbox_size(), + shared_mem.mem_size() + )); + } - pub(super) fn replace_snapshot(&mut self, shared_mem: &mut S) -> Result<()> { - self.snapshot = shared_mem.with_exclusivity(|e| e.copy_all_to_vec())??; - Ok(()) - } + let mut pages_to_restore = HashSet::new(); + for page_num in bit_index_iterator(current_dirty_pages) { + pages_to_restore.insert(page_num); + } + + if let Some(most_recent) = most_recent_snapshot { + // Check if we're restoring to the same snapshot state the sandbox was most recently in + if Arc::ptr_eq(self, most_recent) { + // No need to collect additional pages - just use current dirty pages + } + // Check if we're "rolling back" to a previous state, i.e. `self` is older than `most_recent` (`self` is ancestor of `most_recent`) + // We need to restore all pages that exist in `most_recent` and its ancestors up to `self` (exclusive) + else if self.is_ancestor_of(most_recent) { + let mut current_snapshot = Some(most_recent); + + while let Some(snapshot) = current_snapshot { + if Arc::ptr_eq(snapshot, self) { + break; + } + + for page_num in snapshot.data.keys() { + pages_to_restore.insert(*page_num); + } + + current_snapshot = snapshot.parent.as_ref(); + } + // Check if we're "fast forwarding" to a "newer" state, i.e. 
`self` is newer than `most_recent` (`most_recent` is an ancestor of `self`) + // We need to restore all pages that exist in `self` and its ancestors up to `most_recent` (exclusive) + } else if most_recent.is_ancestor_of(self) { + let mut current_snapshot = Some(self); + + while let Some(snapshot) = current_snapshot { + if Arc::ptr_eq(snapshot, most_recent) { + break; + } + + for page_num in snapshot.data.keys() { + pages_to_restore.insert(*page_num); + } + + current_snapshot = snapshot.parent.as_ref(); + } + } else { + // Neither is an ancestor of the other - they're on different branches + // This is not supported for now + return Err(crate::new_error!( + "Cannot restore between snapshots on different branches" + )); + } + } else { + // No previous snapshots exist, meaning we're restoring a fresh sandbox to this snapshot. + // We need to restore pages that exist in this snapshot and all its ancestors. + let mut current_snapshot = Some(self); + + while let Some(snapshot) = current_snapshot { + for page_num in snapshot.data.keys() { + pages_to_restore.insert(*page_num); + } + + current_snapshot = snapshot.parent.as_ref(); + } + } + + // Restore all collected pages + shared_mem.with_exclusivity(|e| -> Result<()> { + for page_num in pages_to_restore { + let offset = page_num * PAGE_SIZE_USIZE; + + // Search backward through snapshots to find the page + if let Some(page_data) = self.find_page_in_snapshots(page_num) { + // Restore from snapshot + // # Safety: We don't want to dirty the pages we restore + unsafe { e.restore_copy_from_slice(page_data, offset)? }; + } else { + // Zero the page (return to initial state) + // # Safety: We don't want to dirty the pages we restore + unsafe { e.restore_zero_fill(offset, PAGE_SIZE_USIZE)? }; + } + } + Ok(()) + })??; - /// Copy the memory from the internally-stored memory snapshot - /// into the internally-stored `SharedMemory` - #[instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace")] - pub(super) fn restore_from_snapshot(&self, shared_mem: &mut S) -> Result { - shared_mem.with_exclusivity(|e| e.copy_from_slice(self.snapshot.as_slice(), 0))??; Ok(self.mapped_rgns) } - /// Return the size of the snapshot in bytes. + /// Check if this snapshot is an ancestor of the other snapshot + /// (i.e., the other snapshot was taken after this one in the same chain) + fn is_ancestor_of(self: &Arc<Self>, other: &Arc<Self>) -> bool { + let mut current = other.parent.as_ref(); + + while let Some(snapshot) = current { + if Arc::ptr_eq(self, snapshot) { + return true; + } + current = snapshot.parent.as_ref(); + } + + false + } + + /// Search backward through the snapshot chain to find a page + fn find_page_in_snapshots(&self, page_num: usize) -> Option<&Vec<u8>> { + // Check this snapshot first + if let Some(page_data) = self.data.get(&page_num) { + return Some(page_data); + } + + // Check parent snapshots recursively + if let Some(parent) = &self.parent { + return parent.find_page_in_snapshots(page_num); + } + + None + } + + /// Return the size of the sandbox this snapshot was taken from. 
#[instrument(skip_all, parent = Span::current(), level= "Trace")] - pub(super) fn mem_size(&self) -> usize { - self.snapshot.len() + pub(super) fn sandbox_size(&self) -> usize { + self.sandbox_size } } #[cfg(test)] mod tests { + use std::sync::Arc; + use hyperlight_common::mem::PAGE_SIZE_USIZE; use crate::mem::shared_mem::ExclusiveSharedMemory; #[test] - fn restore_replace() { - let mut data1 = vec![b'a', b'b', b'c']; - data1.resize_with(PAGE_SIZE_USIZE, || 0); - let data2 = data1.iter().map(|b| b + 1).collect::>(); + fn test_fresh_sandbox_restoration() { + // Test restoring a fresh sandbox to a given snapshot + let mut gm = ExclusiveSharedMemory::new(PAGE_SIZE_USIZE * 3).unwrap(); + + // Set up initial data in three pages + let page0_data = vec![b'A'; PAGE_SIZE_USIZE]; + let page1_data = vec![b'B'; PAGE_SIZE_USIZE]; + let page2_data = vec![b'C'; PAGE_SIZE_USIZE]; + + gm.copy_from_slice(&page0_data, 0).unwrap(); + gm.copy_from_slice(&page1_data, PAGE_SIZE_USIZE).unwrap(); + gm.copy_from_slice(&page2_data, PAGE_SIZE_USIZE * 2) + .unwrap(); + + // Take snapshot with pages 0 and 2 dirty + let dirty_bitmap = [0b101]; + let snapshot = + Arc::new(super::SharedMemorySnapshot::new(&mut gm, &dirty_bitmap, 0, None).unwrap()); + assert_eq!(snapshot.data.len(), 2); + assert!(snapshot.data.contains_key(&0)); + assert!(!snapshot.data.contains_key(&1)); + assert!(snapshot.data.contains_key(&2)); + + // Create fresh sandbox + let mut fresh_gm = ExclusiveSharedMemory::new(PAGE_SIZE_USIZE * 3).unwrap(); + + // Restore fresh sandbox to snapshot state + snapshot + .restore_from_snapshot(&mut fresh_gm, &[], &None) + .unwrap(); + + // Verify only pages 0 and 2 were restored, page 1 remains zero + let restored_data = fresh_gm.copy_all_to_vec().unwrap(); + assert_eq!(&restored_data[0..PAGE_SIZE_USIZE], &page0_data); + assert_eq!( + &restored_data[PAGE_SIZE_USIZE..PAGE_SIZE_USIZE * 2], + &vec![0u8; PAGE_SIZE_USIZE] + ); + assert_eq!( + &restored_data[PAGE_SIZE_USIZE * 2..PAGE_SIZE_USIZE * 3], + &page2_data + ); + } + + #[test] + fn test_rollback_to_ancestor() { + let mut gm = ExclusiveSharedMemory::new(PAGE_SIZE_USIZE * 4).unwrap(); + + // Initial state - modify page 0 + let page0_v1 = vec![b'1'; PAGE_SIZE_USIZE]; + gm.copy_from_slice(&page0_v1, 0).unwrap(); + let snapshot1 = + Arc::new(super::SharedMemorySnapshot::new(&mut gm, &[0b1], 0, None).unwrap()); + assert_eq!(snapshot1.data.len(), 1); + + // Modify page 1, take snapshot 2 + let page1_v1 = vec![b'2'; PAGE_SIZE_USIZE]; + gm.copy_from_slice(&page1_v1, PAGE_SIZE_USIZE).unwrap(); + let snapshot2 = Arc::new( + super::SharedMemorySnapshot::new(&mut gm, &[0b10], 0, Some(snapshot1.clone())).unwrap(), + ); + assert_eq!(snapshot2.data.len(), 1); + + // Modify page 2, take snapshot 3 + let page2_v1 = vec![b'3'; PAGE_SIZE_USIZE]; + gm.copy_from_slice(&page2_v1, PAGE_SIZE_USIZE * 2).unwrap(); + let snapshot3 = Arc::new( + super::SharedMemorySnapshot::new(&mut gm, &[0b100], 0, Some(snapshot2.clone())) + .unwrap(), + ); + assert_eq!(snapshot3.data.len(), 1); + + // Make additional changes to page 3 (current dirty) + let page3_v1 = vec![b'4'; PAGE_SIZE_USIZE]; + gm.copy_from_slice(&page3_v1, PAGE_SIZE_USIZE * 3).unwrap(); + let current_dirty = [0b1000]; // page 3 is dirty + + // Rollback to snapshot1 (most recent is snapshot3) + snapshot1 + .restore_from_snapshot(&mut gm, ¤t_dirty, &Some(snapshot3)) + .unwrap(); + + let restored_data = gm.copy_all_to_vec().unwrap(); + + // Page 0 should be restored to snapshot1's value + assert_eq!(&restored_data[0..PAGE_SIZE_USIZE], 
&page0_v1); + // Pages 1, 2, 3 should be zeroed (not in snapshot1 or its ancestors) + assert_eq!( + &restored_data[PAGE_SIZE_USIZE..PAGE_SIZE_USIZE * 2], + &vec![0u8; PAGE_SIZE_USIZE] + ); + assert_eq!( + &restored_data[PAGE_SIZE_USIZE * 2..PAGE_SIZE_USIZE * 3], + &vec![0u8; PAGE_SIZE_USIZE] + ); + assert_eq!( + &restored_data[PAGE_SIZE_USIZE * 3..PAGE_SIZE_USIZE * 4], + &vec![0u8; PAGE_SIZE_USIZE] + ); + } + + #[test] + fn test_rollback_then_rollforward() { + let mut gm = ExclusiveSharedMemory::new(PAGE_SIZE_USIZE * 4).unwrap(); + + // Helper function to verify page contents + let verify_page = |gm: &ExclusiveSharedMemory, page_idx: usize, expected: u8| { + let data = gm.copy_all_to_vec().unwrap(); + let page_start = page_idx * PAGE_SIZE_USIZE; + let page_end = page_start + PAGE_SIZE_USIZE; + assert_eq!(data[page_start..page_end], vec![expected; PAGE_SIZE_USIZE]); + }; + + let verify_page_zero = |gm: &ExclusiveSharedMemory, page_idx: usize| { + let data = gm.copy_all_to_vec().unwrap(); + let page_start = page_idx * PAGE_SIZE_USIZE; + let page_end = page_start + PAGE_SIZE_USIZE; + assert_eq!(data[page_start..page_end], vec![0u8; PAGE_SIZE_USIZE]); + }; + + // Initial state: all pages zero + verify_page_zero(&gm, 0); + verify_page_zero(&gm, 1); + verify_page_zero(&gm, 2); + verify_page_zero(&gm, 3); + + // === SNAPSHOT 1 === + // Modify page 0 to 'A' and take snapshot1 + let page0_data = vec![b'A'; PAGE_SIZE_USIZE]; + gm.copy_from_slice(&page0_data, 0).unwrap(); + let snapshot1 = + Arc::new(super::SharedMemorySnapshot::new(&mut gm, &[0b1], 0, None).unwrap()); + + // Verify snapshot1 contains only page 0 + assert_eq!(snapshot1.data.len(), 1); + assert!(snapshot1.data.contains_key(&0)); + assert_eq!(snapshot1.data[&0], page0_data); + + // Verify current memory state + verify_page(&gm, 0, b'A'); + verify_page_zero(&gm, 1); + verify_page_zero(&gm, 2); + verify_page_zero(&gm, 3); + + // === SNAPSHOT 2 === + // Modify page 1 to 'B' and take snapshot2 + let page1_data = vec![b'B'; PAGE_SIZE_USIZE]; + gm.copy_from_slice(&page1_data, PAGE_SIZE_USIZE).unwrap(); + let snapshot2 = Arc::new( + super::SharedMemorySnapshot::new(&mut gm, &[0b10], 0, Some(snapshot1.clone())).unwrap(), + ); + + // Verify snapshot2 contains only page 1 (page 0 unchanged) + assert_eq!(snapshot2.data.len(), 1); + assert!(snapshot2.data.contains_key(&1)); + assert_eq!(snapshot2.data[&1], page1_data); + + // Verify current memory state + verify_page(&gm, 0, b'A'); + verify_page(&gm, 1, b'B'); + verify_page_zero(&gm, 2); + verify_page_zero(&gm, 3); + + // === SNAPSHOT 3 === + // Modify page 2 to 'C' and take snapshot3 + let page2_data = vec![b'C'; PAGE_SIZE_USIZE]; + gm.copy_from_slice(&page2_data, PAGE_SIZE_USIZE * 2) + .unwrap(); + let snapshot3 = Arc::new( + super::SharedMemorySnapshot::new(&mut gm, &[0b100], 0, Some(snapshot2.clone())) + .unwrap(), + ); + + // Verify snapshot3 contains only page 2 + assert_eq!(snapshot3.data.len(), 1); + assert!(snapshot3.data.contains_key(&2)); + assert_eq!(snapshot3.data[&2], page2_data); + + // Verify current memory state + verify_page(&gm, 0, b'A'); + verify_page(&gm, 1, b'B'); + verify_page(&gm, 2, b'C'); + verify_page_zero(&gm, 3); + + // Make some additional changes (dirty page 3 to 'D') + let page3_data = vec![b'D'; PAGE_SIZE_USIZE]; + gm.copy_from_slice(&page3_data, PAGE_SIZE_USIZE * 3) + .unwrap(); + let current_dirty = [0b1000]; // page 3 is dirty + + // Verify current memory state before restoration + verify_page(&gm, 0, b'A'); + verify_page(&gm, 1, b'B'); + verify_page(&gm, 2, 
b'C'); + verify_page(&gm, 3, b'D'); + + // === RESTORE TO SNAPSHOT 2 (rollback) === + snapshot2 + .restore_from_snapshot(&mut gm, ¤t_dirty, &Some(snapshot3.clone())) + .unwrap(); + + // After rollback to snapshot2: + // - Page 0: 'A' (from snapshot1, ancestor of snapshot2) + // - Page 1: 'B' (from snapshot2) + // - Page 2: zeroed (was dirty after snapshot2, not in snapshot2 or ancestors) + // - Page 3: zeroed (was dirty, not in snapshot2 or ancestors) + verify_page(&gm, 0, b'A'); + verify_page(&gm, 1, b'B'); + verify_page_zero(&gm, 2); + verify_page_zero(&gm, 3); + + // === RESTORE TO SNAPSHOT 1 (further rollback) === + snapshot1 + .restore_from_snapshot(&mut gm, &[], &Some(snapshot2.clone())) + .unwrap(); + + // After rollback to snapshot1: + // - Page 0: 'A' (from snapshot1) + // - Page 1: zeroed (was dirty after snapshot1, not in snapshot1 or ancestors) + // - Page 2: remains zero + // - Page 3: remains zero + verify_page(&gm, 0, b'A'); + verify_page_zero(&gm, 1); + verify_page_zero(&gm, 2); + verify_page_zero(&gm, 3); + + // === RESTORE TO SNAPSHOT 2 (forward) === + snapshot2 + .restore_from_snapshot(&mut gm, &[], &Some(snapshot1.clone())) + .unwrap(); + + // After fast-forward to snapshot2: + // - Page 0: 'A' (from snapshot1, ancestor) + // - Page 1: 'B' (from snapshot2) + // - Page 2: remains zero + // - Page 3: remains zero + verify_page(&gm, 0, b'A'); + verify_page(&gm, 1, b'B'); + verify_page_zero(&gm, 2); + verify_page_zero(&gm, 3); + + // === RESTORE TO SNAPSHOT 3 (forward) === + snapshot3 + .restore_from_snapshot(&mut gm, &[], &Some(snapshot2.clone())) + .unwrap(); + + // After fast-forward to snapshot3: + // - Page 0: 'A' (from snapshot1, ancestor) + // - Page 1: 'B' (from snapshot2, ancestor) + // - Page 2: 'C' (from snapshot3) + // - Page 3: remains zero + verify_page(&gm, 0, b'A'); + verify_page(&gm, 1, b'B'); + verify_page(&gm, 2, b'C'); + verify_page_zero(&gm, 3); + } + + #[test] + fn test_current_dirty_pages_restoration() { + // Test that current dirty pages are properly restored + let mut gm = ExclusiveSharedMemory::new(PAGE_SIZE_USIZE * 2).unwrap(); + + // Set up page 0 and take snapshot + let page0_snapshot_data = vec![b'S'; PAGE_SIZE_USIZE]; + gm.copy_from_slice(&page0_snapshot_data, 0).unwrap(); + let snapshot = Arc::new(super::SharedMemorySnapshot::new(&mut gm, &[1], 0, None).unwrap()); + + // Modify both pages after snapshot + let page0_modified = vec![b'M'; PAGE_SIZE_USIZE]; + let page1_modified = vec![b'N'; PAGE_SIZE_USIZE]; + gm.copy_from_slice(&page0_modified, 0).unwrap(); + gm.copy_from_slice(&page1_modified, PAGE_SIZE_USIZE) + .unwrap(); + + // Current dirty pages: both page 0 and 1 + let current_dirty = [3u64]; // pages 0 and 1 are dirty + + // Restore to snapshot + snapshot + .restore_from_snapshot(&mut gm, ¤t_dirty, &None) + .unwrap(); + + let restored_data = gm.copy_all_to_vec().unwrap(); + + // Page 0 should be restored to snapshot value + assert_eq!(&restored_data[0..PAGE_SIZE_USIZE], &page0_snapshot_data); + // Page 1 should be zeroed (wasn't in snapshot) + assert_eq!( + &restored_data[PAGE_SIZE_USIZE..PAGE_SIZE_USIZE * 2], + &vec![0u8; PAGE_SIZE_USIZE] + ); + } + + #[test] + fn test_snapshot_chain_search() { + // Test that page search works through snapshot chain + let mut gm = ExclusiveSharedMemory::new(PAGE_SIZE_USIZE * 2).unwrap(); + + // Snapshot 1: page 0 = 'A' + let page0_v1 = vec![b'A'; PAGE_SIZE_USIZE]; + gm.copy_from_slice(&page0_v1, 0).unwrap(); + let snapshot1 = Arc::new(super::SharedMemorySnapshot::new(&mut gm, &[1], 0, None).unwrap()); + 
+ // Snapshot 2: page 1 = 'B' (page 0 unchanged, not stored again) + let page1_v1 = vec![b'B'; PAGE_SIZE_USIZE]; + gm.copy_from_slice(&page1_v1, PAGE_SIZE_USIZE).unwrap(); + let snapshot2 = Arc::new( + super::SharedMemorySnapshot::new(&mut gm, &[2], 0, Some(snapshot1.clone())).unwrap(), + ); + + // Snapshot 3: page 0 = 'C' (overwrites page 0) + let page0_v2 = vec![b'C'; PAGE_SIZE_USIZE]; + gm.copy_from_slice(&page0_v2, 0).unwrap(); + let snapshot3 = Arc::new( + super::SharedMemorySnapshot::new(&mut gm, &[1], 0, Some(snapshot2.clone())).unwrap(), + ); + + // Clear memory and restore to snapshot3 + gm = ExclusiveSharedMemory::new(PAGE_SIZE_USIZE * 2).unwrap(); + snapshot3 + .restore_from_snapshot(&mut gm, &[], &None) + .unwrap(); + + let restored_data = gm.copy_all_to_vec().unwrap(); + + // Page 0 should have snapshot3's value (most recent) + assert_eq!(&restored_data[0..PAGE_SIZE_USIZE], &page0_v2); + // Page 1 should have snapshot2's value (found in parent) + assert_eq!( + &restored_data[PAGE_SIZE_USIZE..PAGE_SIZE_USIZE * 2], + &page1_v1 + ); + } + + #[test] + fn test_ancestor_relationship() { let mut gm = ExclusiveSharedMemory::new(PAGE_SIZE_USIZE).unwrap(); - gm.copy_from_slice(data1.as_slice(), 0).unwrap(); - let mut snap = super::SharedMemorySnapshot::new(&mut gm, 0).unwrap(); - { - // after the first snapshot is taken, make sure gm has the equivalent - // of data1 - assert_eq!(data1, gm.copy_all_to_vec().unwrap()); - } - { - // modify gm with data2 rather than data1 and restore from - // snapshot. we should have the equivalent of data1 again - gm.copy_from_slice(data2.as_slice(), 0).unwrap(); - assert_eq!(data2, gm.copy_all_to_vec().unwrap()); - snap.restore_from_snapshot(&mut gm).unwrap(); - assert_eq!(data1, gm.copy_all_to_vec().unwrap()); - } - { - // modify gm with data2, then retake the snapshot and restore - // from the new snapshot. 
we should have the equivalent of data2 - gm.copy_from_slice(data2.as_slice(), 0).unwrap(); - assert_eq!(data2, gm.copy_all_to_vec().unwrap()); - snap.replace_snapshot(&mut gm).unwrap(); - assert_eq!(data2, gm.copy_all_to_vec().unwrap()); - snap.restore_from_snapshot(&mut gm).unwrap(); - assert_eq!(data2, gm.copy_all_to_vec().unwrap()); - } + let snapshot1 = Arc::new(super::SharedMemorySnapshot::new(&mut gm, &[], 0, None).unwrap()); + let snapshot2 = Arc::new( + super::SharedMemorySnapshot::new(&mut gm, &[], 0, Some(snapshot1.clone())).unwrap(), + ); + let snapshot3 = Arc::new( + super::SharedMemorySnapshot::new(&mut gm, &[], 0, Some(snapshot2.clone())).unwrap(), + ); + + // Test ancestor relationships + assert!(snapshot1.is_ancestor_of(&snapshot2)); + assert!(snapshot1.is_ancestor_of(&snapshot3)); + assert!(snapshot2.is_ancestor_of(&snapshot3)); + + // Test non-ancestor relationships + assert!(!snapshot2.is_ancestor_of(&snapshot1)); + assert!(!snapshot3.is_ancestor_of(&snapshot1)); + assert!(!snapshot3.is_ancestor_of(&snapshot2)); + + // Test self-relationship + assert!(!snapshot1.is_ancestor_of(&snapshot1)); + } + + #[test] + fn test_different_branches_error() { + // Test error when trying to restore between snapshots on different branches + let mut gm = ExclusiveSharedMemory::new(PAGE_SIZE_USIZE).unwrap(); + + // Common ancestor + let ancestor = Arc::new(super::SharedMemorySnapshot::new(&mut gm, &[], 0, None).unwrap()); + + // Two different branches + let branch1 = Arc::new( + super::SharedMemorySnapshot::new(&mut gm, &[], 0, Some(ancestor.clone())).unwrap(), + ); + let branch2 = Arc::new( + super::SharedMemorySnapshot::new(&mut gm, &[], 0, Some(ancestor.clone())).unwrap(), + ); + + // Trying to restore between different branches should error + let result = branch1.restore_from_snapshot(&mut gm, &[], &Some(branch2)); + assert!(result.is_err()); + assert!( + result + .unwrap_err() + .to_string() + .contains("different branches") + ); + } + + #[test] + fn test_restore_different_size() { + let mut gm = ExclusiveSharedMemory::new(PAGE_SIZE_USIZE * 2).unwrap(); + gm.copy_from_slice(&[1u8; PAGE_SIZE_USIZE], 0).unwrap(); + let snapshot = + Arc::new(super::SharedMemorySnapshot::new(&mut gm, &[0b1], 0, None).unwrap()); + + let mut bigger_gm = ExclusiveSharedMemory::new(PAGE_SIZE_USIZE * 3).unwrap(); + + let result = snapshot.restore_from_snapshot(&mut bigger_gm, &[], &Some(snapshot.clone())); + assert!(result.is_err()); + assert!( + result + .unwrap_err() + .to_string() + .contains("Snapshot size does not match current memory size") + ); + } + + #[test] + fn test_restore_to_same_snapshot() { + // Test restoring to the same snapshot that is currently the most recent + let mut gm = ExclusiveSharedMemory::new(PAGE_SIZE_USIZE * 2).unwrap(); + + let page0_data = vec![b'X'; PAGE_SIZE_USIZE]; + gm.copy_from_slice(&page0_data, 0).unwrap(); + + // Take a snapshot + let snapshot = Arc::new(super::SharedMemorySnapshot::new(&mut gm, &[1], 0, None).unwrap()); + + // Modify memory after snapshot + let page1_data = vec![b'Y'; PAGE_SIZE_USIZE]; + gm.copy_from_slice(&page1_data, PAGE_SIZE_USIZE).unwrap(); + + // Current dirty pages: page 1 is dirty + let current_dirty = [2u64]; // page 1 is dirty + + // Restore to the same snapshot that is the most recent + snapshot + .restore_from_snapshot(&mut gm, ¤t_dirty, &Some(snapshot.clone())) + .unwrap(); + + let restored_data = gm.copy_all_to_vec().unwrap(); + + // Page 0 should remain unchanged (it was the snapshot data) + assert_eq!(&restored_data[0..PAGE_SIZE_USIZE], 
&page0_data);
+        // Page 1 should be zeroed (it was dirty but not in the snapshot)
+        assert_eq!(
+            &restored_data[PAGE_SIZE_USIZE..PAGE_SIZE_USIZE * 2],
+            &vec![0u8; PAGE_SIZE_USIZE]
+        );
     }
 }
diff --git a/src/hyperlight_host/src/sandbox/initialized_multi_use.rs b/src/hyperlight_host/src/sandbox/initialized_multi_use.rs
index 6208d8f85..99cb40d22 100644
--- a/src/hyperlight_host/src/sandbox/initialized_multi_use.rs
+++ b/src/hyperlight_host/src/sandbox/initialized_multi_use.rs
@@ -36,11 +36,12 @@ use crate::func::{ParameterTuple, SupportedReturnType};
 use crate::hypervisor::handlers::DbgMemAccessHandlerWrapper;
 use crate::hypervisor::handlers::{MemAccessHandlerCaller, OutBHandlerCaller};
 use crate::hypervisor::{Hypervisor, InterruptHandle};
+use crate::mem::bitmap::bitmap_union;
 #[cfg(unix)]
 use crate::mem::memory_region::MemoryRegionType;
 use crate::mem::memory_region::{MemoryRegion, MemoryRegionFlags};
 use crate::mem::ptr::RawPtr;
-use crate::mem::shared_mem::HostSharedMemory;
+use crate::mem::shared_mem::{HostSharedMemory, SharedMemory};
 use crate::metrics::maybe_time_and_emit_guest_call;
 use crate::{HyperlightError, Result, log_then_return};
 
@@ -95,17 +96,39 @@ impl MultiUseSandbox {
     /// Create a snapshot of the current state of the sandbox's memory.
     #[instrument(err(Debug), skip_all, parent = Span::current())]
     pub fn snapshot(&mut self) -> Result<Snapshot> {
-        let snapshot = self.mem_mgr.unwrap_mgr_mut().snapshot()?;
+        let host_dirty_pages = self
+            .get_mgr_wrapper_mut()
+            .unwrap_mgr_mut()
+            .get_shared_mem_mut()
+            .with_exclusivity(|e| e.get_and_clear_dirty_pages())??;
+        let vm_dirty_pages = self.vm.get_and_clear_dirty_pages()?;
+
+        let dirty_pages_bitmap = bitmap_union(&vm_dirty_pages, &host_dirty_pages);
+
+        let snapshot = self
+            .mem_mgr
+            .unwrap_mgr_mut()
+            .snapshot(&dirty_pages_bitmap)?;
+
         Ok(Snapshot { inner: snapshot })
     }
 
     /// Restore the sandbox's memory to the state captured in the given snapshot.
     #[instrument(err(Debug), skip_all, parent = Span::current())]
     pub fn restore(&mut self, snapshot: &Snapshot) -> Result<()> {
+        let host_dirty_pages = self
+            .get_mgr_wrapper_mut()
+            .unwrap_mgr_mut()
+            .get_shared_mem_mut()
+            .with_exclusivity(|e| e.get_and_clear_dirty_pages())??;
+        let vm_dirty_pages = self.vm.get_and_clear_dirty_pages()?;
+
+        let dirty_pages_bitmap = bitmap_union(&vm_dirty_pages, &host_dirty_pages);
+
         let rgns_to_unmap = self
             .mem_mgr
             .unwrap_mgr_mut()
-            .restore_snapshot(&snapshot.inner)?;
+            .restore_snapshot(&snapshot.inner, &dirty_pages_bitmap)?;
         unsafe { self.vm.unmap_regions(rgns_to_unmap)? };
         Ok(())
     }
diff --git a/src/hyperlight_host/src/sandbox/snapshot.rs b/src/hyperlight_host/src/sandbox/snapshot.rs
index d91f52437..c00aa4487 100644
--- a/src/hyperlight_host/src/sandbox/snapshot.rs
+++ b/src/hyperlight_host/src/sandbox/snapshot.rs
@@ -14,11 +14,12 @@ See the License for the specific language governing permissions and
 limitations under the License.
 */
 
+use std::sync::Arc;
+
 use crate::mem::shared_mem_snapshot::SharedMemorySnapshot;
 
 /// A snapshot capturing the state of the memory in a `MultiUseSandbox`.
 #[derive(Clone)]
 pub struct Snapshot {
-    /// TODO: Use Arc
-    pub(crate) inner: SharedMemorySnapshot,
+    pub(crate) inner: Arc<SharedMemorySnapshot>,
 }
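
The snapshot and restore paths above merge the host-side and VM-side dirty-page bitmaps with `bitmap_union` before handing the result to the memory manager, and the tests express dirty pages as literals such as `[0b101]` (pages 0 and 2). The convention is one `u64` per block of 64 pages, with bit `i` of block `b` marking page `b * 64 + i`. Below is a minimal, self-contained sketch of that convention; `bitmap_union` here is only a stand-in for `crate::mem::bitmap::bitmap_union`, whose real implementation is not part of this diff, and `dirty_page_numbers` is a hypothetical helper added purely for illustration.

// Sketch only: illustrates the block-based dirty-page bitmap convention used above.
// `bitmap_union` is a stand-in for `crate::mem::bitmap::bitmap_union` (not shown in
// this diff); `dirty_page_numbers` is a hypothetical helper for illustration.

const PAGES_IN_BLOCK: usize = 64; // one u64 tracks one block of 64 pages

/// Union two dirty-page bitmaps that cover the same number of blocks.
fn bitmap_union(a: &[u64], b: &[u64]) -> Vec<u64> {
    assert_eq!(a.len(), b.len(), "bitmaps must cover the same memory size");
    a.iter().zip(b.iter()).map(|(x, y)| x | y).collect()
}

/// Expand a block bitmap into the page numbers it marks as dirty.
fn dirty_page_numbers(bitmap: &[u64]) -> Vec<usize> {
    let mut pages = Vec::new();
    for (block, bits) in bitmap.iter().enumerate() {
        for bit in 0..PAGES_IN_BLOCK {
            if bits & (1u64 << bit) != 0 {
                pages.push(block * PAGES_IN_BLOCK + bit);
            }
        }
    }
    pages
}

fn main() {
    // Suppose the host dirtied pages 0 and 2 while the VM dirtied page 3.
    let host_dirty = [0b0101u64];
    let vm_dirty = [0b1000u64];
    let merged = bitmap_union(&vm_dirty, &host_dirty);
    assert_eq!(dirty_page_numbers(&merged), vec![0, 2, 3]);
    println!("dirty pages: {:?}", dirty_page_numbers(&merged));
}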