diff --git a/Cargo.lock b/Cargo.lock index 7ae85a17d..751c2a3bc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1388,6 +1388,7 @@ name = "hyperlight-guest" version = "0.8.0" dependencies = [ "anyhow", + "flatbuffers", "hyperlight-common", "hyperlight-guest-tracing", "serde_json", diff --git a/src/hyperlight_common/src/flatbuffer_wrappers/function_call.rs b/src/hyperlight_common/src/flatbuffer_wrappers/function_call.rs index 67998fbbe..056ced8e0 100644 --- a/src/hyperlight_common/src/flatbuffer_wrappers/function_call.rs +++ b/src/hyperlight_common/src/flatbuffer_wrappers/function_call.rs @@ -18,7 +18,7 @@ use alloc::string::{String, ToString}; use alloc::vec::Vec; use anyhow::{Error, Result, bail}; -use flatbuffers::{WIPOffset, size_prefixed_root}; +use flatbuffers::{FlatBufferBuilder, WIPOffset, size_prefixed_root}; #[cfg(feature = "tracing")] use tracing::{Span, instrument}; @@ -72,214 +72,136 @@ impl FunctionCall { pub fn function_call_type(&self) -> FunctionCallType { self.function_call_type.clone() } -} - -#[cfg_attr(feature = "tracing", instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace"))] -pub fn validate_guest_function_call_buffer(function_call_buffer: &[u8]) -> Result<()> { - let guest_function_call_fb = size_prefixed_root::(function_call_buffer) - .map_err(|e| anyhow::anyhow!("Error reading function call buffer: {:?}", e))?; - match guest_function_call_fb.function_call_type() { - FbFunctionCallType::guest => Ok(()), - other => { - bail!("Invalid function call type: {:?}", other); - } - } -} - -#[cfg_attr(feature = "tracing", instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace"))] -pub fn validate_host_function_call_buffer(function_call_buffer: &[u8]) -> Result<()> { - let host_function_call_fb = size_prefixed_root::(function_call_buffer) - .map_err(|e| anyhow::anyhow!("Error reading function call buffer: {:?}", e))?; - match host_function_call_fb.function_call_type() { - FbFunctionCallType::host => Ok(()), - other 
=> { - bail!("Invalid function call type: {:?}", other); - } - } -} - -impl TryFrom<&[u8]> for FunctionCall { - type Error = Error; - #[cfg_attr(feature = "tracing", instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace"))] - fn try_from(value: &[u8]) -> Result { - let function_call_fb = size_prefixed_root::(value) - .map_err(|e| anyhow::anyhow!("Error reading function call buffer: {:?}", e))?; - let function_name = function_call_fb.function_name(); - let function_call_type = match function_call_fb.function_call_type() { - FbFunctionCallType::guest => FunctionCallType::Guest, - FbFunctionCallType::host => FunctionCallType::Host, - other => { - bail!("Invalid function call type: {:?}", other); - } - }; - let expected_return_type = function_call_fb.expected_return_type().try_into()?; - let parameters = function_call_fb - .parameters() - .map(|v| { - v.iter() - .map(|p| p.try_into()) - .collect::>>() - }) - .transpose()?; + /// Encodes self into the given builder and returns the encoded data. + /// + /// # Notes + /// + /// The builder should not be reused after a call to encode, since this function + /// does not reset the state of the builder. If you want to reuse the builder, + /// you'll need to reset it first. 
+ pub fn encode<'a>(&self, builder: &'a mut FlatBufferBuilder) -> &'a [u8] { + let function_name = builder.create_string(&self.function_name); - Ok(Self { - function_name: function_name.to_string(), - parameters, - function_call_type, - expected_return_type, - }) - } -} - -impl TryFrom for Vec { - type Error = Error; - #[cfg_attr(feature = "tracing", instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace"))] - fn try_from(value: FunctionCall) -> Result> { - let mut builder = flatbuffers::FlatBufferBuilder::new(); - let function_name = builder.create_string(&value.function_name); - - let function_call_type = match value.function_call_type { + let function_call_type = match self.function_call_type { FunctionCallType::Guest => FbFunctionCallType::guest, FunctionCallType::Host => FbFunctionCallType::host, }; - let expected_return_type = value.expected_return_type.into(); - - let parameters = match &value.parameters { - Some(p) => { - let num_items = p.len(); - let mut parameters: Vec> = Vec::with_capacity(num_items); + let expected_return_type = self.expected_return_type.into(); - for param in p { - match param { + let parameters = match &self.parameters { + Some(p) if !p.is_empty() => { + let parameter_offsets: Vec> = p + .iter() + .map(|param| match param { ParameterValue::Int(i) => { - let hlint = hlint::create(&mut builder, &hlintArgs { value: *i }); - let parameter = Parameter::create( - &mut builder, + let hlint = hlint::create(builder, &hlintArgs { value: *i }); + Parameter::create( + builder, &ParameterArgs { value_type: FbParameterValue::hlint, value: Some(hlint.as_union_value()), }, - ); - parameters.push(parameter); + ) } ParameterValue::UInt(ui) => { - let hluint = hluint::create(&mut builder, &hluintArgs { value: *ui }); - let parameter = Parameter::create( - &mut builder, + let hluint = hluint::create(builder, &hluintArgs { value: *ui }); + Parameter::create( + builder, &ParameterArgs { value_type: FbParameterValue::hluint, value: 
Some(hluint.as_union_value()), }, - ); - parameters.push(parameter); + ) } ParameterValue::Long(l) => { - let hllong = hllong::create(&mut builder, &hllongArgs { value: *l }); - let parameter = Parameter::create( - &mut builder, + let hllong = hllong::create(builder, &hllongArgs { value: *l }); + Parameter::create( + builder, &ParameterArgs { value_type: FbParameterValue::hllong, value: Some(hllong.as_union_value()), }, - ); - parameters.push(parameter); + ) } ParameterValue::ULong(ul) => { - let hlulong = - hlulong::create(&mut builder, &hlulongArgs { value: *ul }); - let parameter = Parameter::create( - &mut builder, + let hlulong = hlulong::create(builder, &hlulongArgs { value: *ul }); + Parameter::create( + builder, &ParameterArgs { value_type: FbParameterValue::hlulong, value: Some(hlulong.as_union_value()), }, - ); - parameters.push(parameter); + ) } ParameterValue::Float(f) => { - let hlfloat = hlfloat::create(&mut builder, &hlfloatArgs { value: *f }); - let parameter = Parameter::create( - &mut builder, + let hlfloat = hlfloat::create(builder, &hlfloatArgs { value: *f }); + Parameter::create( + builder, &ParameterArgs { value_type: FbParameterValue::hlfloat, value: Some(hlfloat.as_union_value()), }, - ); - parameters.push(parameter); + ) } ParameterValue::Double(d) => { - let hldouble = - hldouble::create(&mut builder, &hldoubleArgs { value: *d }); - let parameter = Parameter::create( - &mut builder, + let hldouble = hldouble::create(builder, &hldoubleArgs { value: *d }); + Parameter::create( + builder, &ParameterArgs { value_type: FbParameterValue::hldouble, value: Some(hldouble.as_union_value()), }, - ); - parameters.push(parameter); + ) } ParameterValue::Bool(b) => { - let hlbool: WIPOffset> = - hlbool::create(&mut builder, &hlboolArgs { value: *b }); - let parameter = Parameter::create( - &mut builder, + let hlbool = hlbool::create(builder, &hlboolArgs { value: *b }); + Parameter::create( + builder, &ParameterArgs { value_type: FbParameterValue::hlbool, 
value: Some(hlbool.as_union_value()), }, - ); - parameters.push(parameter); + ) } ParameterValue::String(s) => { - let hlstring = { - let val = builder.create_string(s.as_str()); - hlstring::create(&mut builder, &hlstringArgs { value: Some(val) }) - }; - let parameter = Parameter::create( - &mut builder, + let val = builder.create_string(s.as_str()); + let hlstring = + hlstring::create(builder, &hlstringArgs { value: Some(val) }); + Parameter::create( + builder, &ParameterArgs { value_type: FbParameterValue::hlstring, value: Some(hlstring.as_union_value()), }, - ); - parameters.push(parameter); + ) } ParameterValue::VecBytes(v) => { let vec_bytes = builder.create_vector(v); - let hlvecbytes = hlvecbytes::create( - &mut builder, + builder, &hlvecbytesArgs { value: Some(vec_bytes), }, ); - let parameter = Parameter::create( - &mut builder, + Parameter::create( + builder, &ParameterArgs { value_type: FbParameterValue::hlvecbytes, value: Some(hlvecbytes.as_union_value()), }, - ); - parameters.push(parameter); + ) } - } - } - parameters + }) + .collect(); + Some(builder.create_vector(¶meter_offsets)) } - None => Vec::new(), - }; - - let parameters = if !parameters.is_empty() { - Some(builder.create_vector(¶meters)) - } else { - None + _ => None, }; let function_call = FbFunctionCall::create( - &mut builder, + builder, &FbFunctionCallArgs { function_name: Some(function_name), parameters, @@ -288,9 +210,65 @@ impl TryFrom for Vec { }, ); builder.finish_size_prefixed(function_call, None); - let res = builder.finished_data().to_vec(); + builder.finished_data() + } +} - Ok(res) +#[cfg_attr(feature = "tracing", instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace"))] +pub fn validate_guest_function_call_buffer(function_call_buffer: &[u8]) -> Result<()> { + let guest_function_call_fb = size_prefixed_root::(function_call_buffer) + .map_err(|e| anyhow::anyhow!("Error reading function call buffer: {:?}", e))?; + match 
guest_function_call_fb.function_call_type() { + FbFunctionCallType::guest => Ok(()), + other => { + bail!("Invalid function call type: {:?}", other); + } + } +} + +#[cfg_attr(feature = "tracing", instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace"))] +pub fn validate_host_function_call_buffer(function_call_buffer: &[u8]) -> Result<()> { + let host_function_call_fb = size_prefixed_root::(function_call_buffer) + .map_err(|e| anyhow::anyhow!("Error reading function call buffer: {:?}", e))?; + match host_function_call_fb.function_call_type() { + FbFunctionCallType::host => Ok(()), + other => { + bail!("Invalid function call type: {:?}", other); + } + } +} + +impl TryFrom<&[u8]> for FunctionCall { + type Error = Error; + #[cfg_attr(feature = "tracing", instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace"))] + fn try_from(value: &[u8]) -> Result { + let function_call_fb = size_prefixed_root::(value) + .map_err(|e| anyhow::anyhow!("Error reading function call buffer: {:?}", e))?; + let function_name = function_call_fb.function_name(); + let function_call_type = match function_call_fb.function_call_type() { + FbFunctionCallType::guest => FunctionCallType::Guest, + FbFunctionCallType::host => FunctionCallType::Host, + other => { + bail!("Invalid function call type: {:?}", other); + } + }; + let expected_return_type = function_call_fb.expected_return_type().try_into()?; + + let parameters = function_call_fb + .parameters() + .map(|v| { + v.iter() + .map(|p| p.try_into()) + .collect::>>() + }) + .transpose()?; + + Ok(Self { + function_name: function_name.to_string(), + parameters, + function_call_type, + expected_return_type, + }) } } @@ -303,7 +281,8 @@ mod tests { #[test] fn read_from_flatbuffer() -> Result<()> { - let test_data: Vec = FunctionCall::new( + let mut builder = FlatBufferBuilder::new(); + let test_data = FunctionCall::new( "PrintTwelveArgs".to_string(), Some(vec![ ParameterValue::String("1".to_string()), @@ -322,10 
+301,9 @@ mod tests { FunctionCallType::Guest, ReturnType::Int, ) - .try_into() - .unwrap(); + .encode(&mut builder); - let function_call = FunctionCall::try_from(test_data.as_slice())?; + let function_call = FunctionCall::try_from(test_data)?; assert_eq!(function_call.function_name, "PrintTwelveArgs"); assert!(function_call.parameters.is_some()); let parameters = function_call.parameters.unwrap(); diff --git a/src/hyperlight_common/src/flatbuffer_wrappers/function_types.rs b/src/hyperlight_common/src/flatbuffer_wrappers/function_types.rs index b381da592..16bd91bc0 100644 --- a/src/hyperlight_common/src/flatbuffer_wrappers/function_types.rs +++ b/src/hyperlight_common/src/flatbuffer_wrappers/function_types.rs @@ -181,7 +181,7 @@ impl TryFrom> for ParameterValue { ParameterValue::String(hlstring.value().unwrap_or_default().to_string()) }), FbParameterValue::hlvecbytes => param.value_as_hlvecbytes().map(|hlvecbytes| { - ParameterValue::VecBytes(hlvecbytes.value().unwrap_or_default().iter().collect()) + ParameterValue::VecBytes(hlvecbytes.value().unwrap_or_default().bytes().to_vec()) }), other => { bail!("Unexpected flatbuffer parameter value type: {:?}", other); diff --git a/src/hyperlight_common/src/flatbuffer_wrappers/util.rs b/src/hyperlight_common/src/flatbuffer_wrappers/util.rs index ba3645b94..96ab16f1e 100644 --- a/src/hyperlight_common/src/flatbuffer_wrappers/util.rs +++ b/src/hyperlight_common/src/flatbuffer_wrappers/util.rs @@ -18,6 +18,7 @@ use alloc::vec::Vec; use flatbuffers::FlatBufferBuilder; +use crate::flatbuffer_wrappers::function_types::ParameterValue; use crate::flatbuffers::hyperlight::generated::{ FunctionCallResult as FbFunctionCallResult, FunctionCallResultArgs as FbFunctionCallResultArgs, ReturnValue as FbReturnValue, hlbool as Fbhlbool, hlboolArgs as FbhlboolArgs, @@ -169,3 +170,350 @@ impl FlatbufferSerializable for bool { } } } + +/// Estimates the required buffer capacity for encoding a FunctionCall with the given parameters. 
+/// This helps avoid reallocation during FlatBuffer encoding when passing large slices and strings. +/// +/// The function aims to be lightweight and fast and run in O(1) as long as the number of parameters is limited +/// (which it is since hyperlight only currently supports up to 12). +/// +/// Note: This estimates the capacity needed for the inner vec inside a FlatBufferBuilder. It does not +/// necessarily match the size of the final encoded buffer. The estimation always rounds up to the +/// nearest power of two to match FlatBufferBuilder's allocation strategy. +/// +/// The numbers used in the estimation are empirically derived based on the tests below and vaguely based +/// on https://flatbuffers.dev/internals/ and https://github.com/dvidelabs/flatcc/blob/f064cefb2034d1e7407407ce32a6085c322212a7/doc/binary-format.md#flatbuffers-binary-format +#[inline] // allow cross-crate inlining (for hyperlight-host calls) +pub fn estimate_flatbuffer_capacity(function_name: &str, args: &[ParameterValue]) -> usize { + let mut estimated_capacity = 20; + + // Function name overhead + estimated_capacity += function_name.len() + 12; + + // Parameters vector overhead + estimated_capacity += 12 + args.len() * 6; + + // Per-parameter overhead + for arg in args { + estimated_capacity += 16; // Base parameter structure + estimated_capacity += match arg { + ParameterValue::String(s) => s.len() + 20, + ParameterValue::VecBytes(v) => v.len() + 20, + ParameterValue::Int(_) | ParameterValue::UInt(_) => 16, + ParameterValue::Long(_) | ParameterValue::ULong(_) => 20, + ParameterValue::Float(_) => 16, + ParameterValue::Double(_) => 20, + ParameterValue::Bool(_) => 12, + }; + } + + // match how vec grows + estimated_capacity.next_power_of_two() +} + +#[cfg(test)] +mod tests { + use alloc::string::ToString; + use alloc::vec; + use alloc::vec::Vec; + + use super::*; + use crate::flatbuffer_wrappers::function_call::{FunctionCall, FunctionCallType}; + use 
crate::flatbuffer_wrappers::function_types::{ParameterValue, ReturnType}; + + /// Helper function to check that estimation is within reasonable bounds (±25%) + fn assert_estimation_accuracy( + function_name: &str, + args: Vec, + call_type: FunctionCallType, + return_type: ReturnType, + ) { + let estimated = estimate_flatbuffer_capacity(function_name, &args); + + let fc = FunctionCall::new( + function_name.to_string(), + Some(args), + call_type.clone(), + return_type, + ); + // Important that this FlatBufferBuilder is created with capacity 0 so it grows to its needed capacity + let mut builder = FlatBufferBuilder::new(); + let _buffer = fc.encode(&mut builder); + let actual = builder.collapse().0.capacity(); + + let lower_bound = (actual as f64 * 0.75) as usize; + let upper_bound = (actual as f64 * 1.25) as usize; + + assert!( + estimated >= lower_bound && estimated <= upper_bound, + "Estimation {} outside bounds [{}, {}] for actual size {} (function: {}, call_type: {:?}, return_type: {:?})", + estimated, + lower_bound, + upper_bound, + actual, + function_name, + call_type, + return_type + ); + } + + #[test] + fn test_estimate_no_parameters() { + assert_estimation_accuracy( + "simple_function", + vec![], + FunctionCallType::Guest, + ReturnType::Void, + ); + } + + #[test] + fn test_estimate_single_int_parameter() { + assert_estimation_accuracy( + "add_one", + vec![ParameterValue::Int(42)], + FunctionCallType::Guest, + ReturnType::Int, + ); + } + + #[test] + fn test_estimate_multiple_scalar_parameters() { + assert_estimation_accuracy( + "calculate", + vec![ + ParameterValue::Int(10), + ParameterValue::UInt(20), + ParameterValue::Long(30), + ParameterValue::ULong(40), + ParameterValue::Float(1.5), + ParameterValue::Double(2.5), + ParameterValue::Bool(true), + ], + FunctionCallType::Guest, + ReturnType::Double, + ); + } + + #[test] + fn test_estimate_string_parameters() { + assert_estimation_accuracy( + "process_strings", + vec![ + 
ParameterValue::String("hello".to_string()), + ParameterValue::String("world".to_string()), + ParameterValue::String("this is a longer string for testing".to_string()), + ], + FunctionCallType::Host, + ReturnType::String, + ); + } + + #[test] + fn test_estimate_very_long_string() { + let long_string = "a".repeat(1000); + assert_estimation_accuracy( + "process_long_string", + vec![ParameterValue::String(long_string)], + FunctionCallType::Guest, + ReturnType::String, + ); + } + + #[test] + fn test_estimate_vector_parameters() { + assert_estimation_accuracy( + "process_vectors", + vec![ + ParameterValue::VecBytes(vec![1, 2, 3, 4, 5]), + ParameterValue::VecBytes(vec![]), + ParameterValue::VecBytes(vec![0; 100]), + ], + FunctionCallType::Host, + ReturnType::VecBytes, + ); + } + + #[test] + fn test_estimate_mixed_parameters() { + assert_estimation_accuracy( + "complex_function", + vec![ + ParameterValue::String("test".to_string()), + ParameterValue::Int(42), + ParameterValue::VecBytes(vec![1, 2, 3, 4, 5]), + ParameterValue::Bool(true), + ParameterValue::Double(553.14159), + ParameterValue::String("another string".to_string()), + ParameterValue::Long(9223372036854775807), + ], + FunctionCallType::Guest, + ReturnType::VecBytes, + ); + } + + #[test] + fn test_estimate_large_function_name() { + let long_name = "very_long_function_name_that_exceeds_normal_lengths_for_testing_purposes"; + assert_estimation_accuracy( + long_name, + vec![ParameterValue::Int(1)], + FunctionCallType::Host, + ReturnType::Long, + ); + } + + #[test] + fn test_estimate_large_vector() { + let large_vector = vec![42u8; 10000]; + assert_estimation_accuracy( + "process_large_data", + vec![ParameterValue::VecBytes(large_vector)], + FunctionCallType::Guest, + ReturnType::Bool, + ); + } + + #[test] + fn test_estimate_all_parameter_types() { + assert_estimation_accuracy( + "comprehensive_test", + vec![ + ParameterValue::Int(i32::MIN), + ParameterValue::UInt(u32::MAX), + ParameterValue::Long(i64::MIN), + 
ParameterValue::ULong(u64::MAX), + ParameterValue::Float(f32::MIN), + ParameterValue::Double(f64::MAX), + ParameterValue::Bool(false), + ParameterValue::String("test string".to_string()), + ParameterValue::VecBytes(vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), + ], + FunctionCallType::Host, + ReturnType::ULong, + ); + } + + #[test] + fn test_different_function_call_types() { + assert_estimation_accuracy( + "guest_function", + vec![ParameterValue::String("guest call".to_string())], + FunctionCallType::Guest, + ReturnType::String, + ); + + assert_estimation_accuracy( + "host_function", + vec![ParameterValue::String("host call".to_string())], + FunctionCallType::Host, + ReturnType::String, + ); + } + + #[test] + fn test_different_return_types() { + let args = vec![ + ParameterValue::Int(42), + ParameterValue::String("test".to_string()), + ]; + + let void_est = estimate_flatbuffer_capacity("test_void", &args); + let int_est = estimate_flatbuffer_capacity("test_int", &args); + let string_est = estimate_flatbuffer_capacity("test_string", &args); + + assert!((void_est as i32 - int_est as i32).abs() < 10); + assert!((int_est as i32 - string_est as i32).abs() < 10); + + assert_estimation_accuracy( + "test_void", + args.clone(), + FunctionCallType::Guest, + ReturnType::Void, + ); + assert_estimation_accuracy( + "test_int", + args.clone(), + FunctionCallType::Guest, + ReturnType::Int, + ); + assert_estimation_accuracy( + "test_string", + args, + FunctionCallType::Guest, + ReturnType::String, + ); + } + + #[test] + fn test_estimate_many_large_vectors_and_strings() { + assert_estimation_accuracy( + "process_bulk_data", + vec![ + ParameterValue::String("Large string data: ".to_string() + &"x".repeat(2000)), + ParameterValue::VecBytes(vec![1u8; 5000]), + ParameterValue::String( + "Another large string with lots of content ".to_string() + &"y".repeat(3000), + ), + ParameterValue::VecBytes(vec![255u8; 7500]), + ParameterValue::String( + "Third massive string parameter ".to_string() + 
&"z".repeat(1500), + ), + ParameterValue::VecBytes(vec![128u8; 10000]), + ParameterValue::Int(42), + ParameterValue::String("Final large string ".to_string() + &"a".repeat(4000)), + ParameterValue::VecBytes(vec![64u8; 2500]), + ParameterValue::Bool(true), + ], + FunctionCallType::Host, + ReturnType::VecBytes, + ); + } + + #[test] + fn test_estimate_twenty_parameters() { + assert_estimation_accuracy( + "function_with_many_parameters", + vec![ + ParameterValue::Int(1), + ParameterValue::String("param2".to_string()), + ParameterValue::Bool(true), + ParameterValue::Float(3213.14), + ParameterValue::VecBytes(vec![1, 2, 3]), + ParameterValue::Long(1000000), + ParameterValue::Double(322.718), + ParameterValue::UInt(42), + ParameterValue::String("param9".to_string()), + ParameterValue::Bool(false), + ParameterValue::ULong(9999999999), + ParameterValue::VecBytes(vec![4, 5, 6, 7, 8]), + ParameterValue::Int(-100), + ParameterValue::Float(1.414), + ParameterValue::String("param15".to_string()), + ParameterValue::Double(1.732), + ParameterValue::Bool(true), + ParameterValue::VecBytes(vec![9, 10]), + ParameterValue::Long(-5000000), + ParameterValue::UInt(12345), + ], + FunctionCallType::Guest, + ReturnType::Int, + ); + } + + #[test] + fn test_estimate_megabyte_parameters() { + assert_estimation_accuracy( + "process_megabyte_data", + vec![ + ParameterValue::String("MB String 1: ".to_string() + &"x".repeat(1_048_576)), // 1MB string + ParameterValue::VecBytes(vec![42u8; 2_097_152]), // 2MB vector + ParameterValue::String("MB String 2: ".to_string() + &"y".repeat(1_572_864)), // 1.5MB string + ParameterValue::VecBytes(vec![128u8; 3_145_728]), // 3MB vector + ParameterValue::String("MB String 3: ".to_string() + &"z".repeat(2_097_152)), // 2MB string + ], + FunctionCallType::Host, + ReturnType::VecBytes, + ); + } +} diff --git a/src/hyperlight_guest/Cargo.toml b/src/hyperlight_guest/Cargo.toml index 73a669862..ecc516241 100644 --- a/src/hyperlight_guest/Cargo.toml +++ 
b/src/hyperlight_guest/Cargo.toml @@ -16,6 +16,7 @@ anyhow = { version = "1.0.99", default-features = false } serde_json = { version = "1.0", default-features = false, features = ["alloc"] } hyperlight-common = { workspace = true } hyperlight-guest-tracing = { workspace = true, default-features = false } +flatbuffers = { version= "25.2.10", default-features = false } [features] default = [] diff --git a/src/hyperlight_guest/src/guest_handle/host_comm.rs b/src/hyperlight_guest/src/guest_handle/host_comm.rs index 5719417ab..17b5300e5 100644 --- a/src/hyperlight_guest/src/guest_handle/host_comm.rs +++ b/src/hyperlight_guest/src/guest_handle/host_comm.rs @@ -19,6 +19,7 @@ use alloc::string::ToString; use alloc::vec::Vec; use core::slice::from_raw_parts; +use flatbuffers::FlatBufferBuilder; use hyperlight_common::flatbuffer_wrappers::function_call::{FunctionCall, FunctionCallType}; use hyperlight_common::flatbuffer_wrappers::function_types::{ ParameterValue, ReturnType, ReturnValue, @@ -27,6 +28,7 @@ use hyperlight_common::flatbuffer_wrappers::guest_error::{ErrorCode, GuestError} use hyperlight_common::flatbuffer_wrappers::guest_log_data::GuestLogData; use hyperlight_common::flatbuffer_wrappers::guest_log_level::LogLevel; use hyperlight_common::flatbuffer_wrappers::host_function_details::HostFunctionDetails; +use hyperlight_common::flatbuffer_wrappers::util::estimate_flatbuffer_capacity; use hyperlight_common::outb::OutBAction; use super::handle::GuestHandle; @@ -92,6 +94,9 @@ impl GuestHandle { parameters: Option>, return_type: ReturnType, ) -> Result<()> { + let estimated_capacity = + estimate_flatbuffer_capacity(function_name, parameters.as_deref().unwrap_or(&[])); + let host_function_call = FunctionCall::new( function_name.to_string(), parameters, @@ -99,10 +104,9 @@ impl GuestHandle { return_type, ); - let host_function_call_buffer: Vec = host_function_call - .try_into() - .expect("Unable to serialize host function call"); + let mut builder = 
FlatBufferBuilder::with_capacity(estimated_capacity); + let host_function_call_buffer = host_function_call.encode(&mut builder); self.push_shared_output_data(host_function_call_buffer)?; unsafe { @@ -155,7 +159,7 @@ impl GuestHandle { .try_into() .expect("Invalid guest_error_buffer, could not be converted to a Vec"); - if let Err(e) = self.push_shared_output_data(guest_error_buffer) { + if let Err(e) = self.push_shared_output_data(&guest_error_buffer) { panic!("Unable to push guest error to shared output data: {:#?}", e); } } @@ -184,7 +188,7 @@ impl GuestHandle { .try_into() .expect("Failed to convert GuestLogData to bytes"); - self.push_shared_output_data(bytes) + self.push_shared_output_data(&bytes) .expect("Unable to push log data to shared output data"); unsafe { diff --git a/src/hyperlight_guest/src/guest_handle/io.rs b/src/hyperlight_guest/src/guest_handle/io.rs index 759c88880..d8219b270 100644 --- a/src/hyperlight_guest/src/guest_handle/io.rs +++ b/src/hyperlight_guest/src/guest_handle/io.rs @@ -16,7 +16,6 @@ limitations under the License. use alloc::format; use alloc::string::ToString; -use alloc::vec::Vec; use core::any::type_name; use core::slice::from_raw_parts_mut; @@ -93,7 +92,7 @@ impl GuestHandle { /// Pushes the given data onto the shared output data buffer. 
#[hyperlight_guest_tracing::trace_function] - pub fn push_shared_output_data(&self, data: Vec) -> Result<()> { + pub fn push_shared_output_data(&self, data: &[u8]) -> Result<()> { let peb_ptr = self.peb().unwrap(); let output_stack_size = unsafe { (*peb_ptr).output_stack.size as usize }; let output_stack_ptr = unsafe { (*peb_ptr).output_stack.ptr as *mut u8 }; @@ -139,7 +138,7 @@ impl GuestHandle { // write the actual data hyperlight_guest_tracing::trace!("copy data", { - odb[stack_ptr_rel as usize..stack_ptr_rel as usize + data.len()].copy_from_slice(&data); + odb[stack_ptr_rel as usize..stack_ptr_rel as usize + data.len()].copy_from_slice(data); }); // write the offset to the newly written data, to the top of the stack diff --git a/src/hyperlight_guest_bin/src/guest_function/call.rs b/src/hyperlight_guest_bin/src/guest_function/call.rs index 3d55a9f5e..7eabdeb29 100644 --- a/src/hyperlight_guest_bin/src/guest_function/call.rs +++ b/src/hyperlight_guest_bin/src/guest_function/call.rs @@ -98,7 +98,7 @@ fn internal_dispatch_function() -> Result<()> { handle.write_error(e.kind, Some(e.message.as_str())); })?; - handle.push_shared_output_data(result_vec) + handle.push_shared_output_data(&result_vec) } // This is implemented as a separate function to make sure that epilogue in the internal_dispatch_function is called before the halt() diff --git a/src/hyperlight_host/benches/benchmarks.rs b/src/hyperlight_host/benches/benchmarks.rs index 4896d9e14..3ed91820f 100644 --- a/src/hyperlight_host/benches/benchmarks.rs +++ b/src/hyperlight_host/benches/benchmarks.rs @@ -15,9 +15,13 @@ limitations under the License. 
*/ use criterion::{Criterion, criterion_group, criterion_main}; +use flatbuffers::FlatBufferBuilder; +use hyperlight_common::flatbuffer_wrappers::function_call::{FunctionCall, FunctionCallType}; +use hyperlight_common::flatbuffer_wrappers::function_types::{ParameterValue, ReturnType}; +use hyperlight_common::flatbuffer_wrappers::util::estimate_flatbuffer_capacity; use hyperlight_host::GuestBinary; use hyperlight_host::sandbox::{MultiUseSandbox, SandboxConfiguration, UninitializedSandbox}; -use hyperlight_testing::simple_guest_as_string; +use hyperlight_testing::{c_simple_guest_as_string, simple_guest_as_string}; fn create_uninit_sandbox() -> UninitializedSandbox { let path = simple_guest_as_string().unwrap(); @@ -133,9 +137,91 @@ fn sandbox_benchmark(c: &mut Criterion) { group.finish(); } +fn function_call_serialization_benchmark(c: &mut Criterion) { + let mut group = c.benchmark_group("function_call_serialization"); + + let function_call = FunctionCall::new( + "TestFunction".to_string(), + Some(vec![ + ParameterValue::VecBytes(vec![1; 10 * 1024 * 1024]), + ParameterValue::String(String::from_utf8(vec![2; 10 * 1024 * 1024]).unwrap()), + ParameterValue::Int(42), + ParameterValue::UInt(100), + ParameterValue::Long(1000), + ParameterValue::ULong(2000), + ParameterValue::Float(521521.53), + ParameterValue::Double(432.53), + ParameterValue::Bool(true), + ParameterValue::VecBytes(vec![1; 10 * 1024 * 1024]), + ParameterValue::String(String::from_utf8(vec![2; 10 * 1024 * 1024]).unwrap()), + ]), + FunctionCallType::Guest, + ReturnType::Int, + ); + + group.bench_function("serialize_function_call", |b| { + b.iter(|| { + // We specifically want to include the time to estimate the capacity in this benchmark + let estimated_capacity = estimate_flatbuffer_capacity( + function_call.function_name.as_str(), + function_call.parameters.as_deref().unwrap_or(&[]), + ); + let mut builder = FlatBufferBuilder::with_capacity(estimated_capacity); + let serialized: &[u8] = 
function_call.encode(&mut builder); + std::hint::black_box(serialized); + }); + }); + + group.bench_function("deserialize_function_call", |b| { + let mut builder = FlatBufferBuilder::new(); + let bytes = function_call.clone().encode(&mut builder); + + b.iter(|| { + let deserialized: FunctionCall = bytes.try_into().unwrap(); + std::hint::black_box(deserialized); + }); + }); + + group.finish(); +} + +#[allow(clippy::disallowed_macros)] +fn sample_workloads_benchmark(c: &mut Criterion) { + let mut group = c.benchmark_group("sample_workloads"); + + fn bench_24k_in_8k_out(b: &mut criterion::Bencher, guest_path: String) { + let mut cfg = SandboxConfiguration::default(); + cfg.set_input_data_size(25 * 1024); + + let mut sandbox = UninitializedSandbox::new(GuestBinary::FilePath(guest_path), Some(cfg)) + .unwrap() + .evolve() + .unwrap(); + + b.iter_with_setup( + || vec![1; 24 * 1024], + |input| { + let ret: Vec = sandbox.call("24K_in_8K_out", (input,)).unwrap(); + assert_eq!(ret.len(), 8 * 1024, "Expected output length to be 8K"); + std::hint::black_box(ret); + }, + ); + } + + group.bench_function("24K_in_8K_out_c", |b| { + bench_24k_in_8k_out(b, c_simple_guest_as_string().unwrap()); + }); + + group.bench_function("24K_in_8K_out_rust", |b| { + bench_24k_in_8k_out(b, simple_guest_as_string().unwrap()); + }); + + group.finish(); +} + criterion_group! 
{ name = benches; config = Criterion::default(); - targets = guest_call_benchmark, sandbox_benchmark, guest_call_benchmark_large_param + targets = guest_call_benchmark, sandbox_benchmark, guest_call_benchmark_large_param, function_call_serialization_benchmark, sample_workloads_benchmark } criterion_main!(benches); diff --git a/src/hyperlight_host/src/sandbox/initialized_multi_use.rs b/src/hyperlight_host/src/sandbox/initialized_multi_use.rs index 671dbe7db..ecc650c99 100644 --- a/src/hyperlight_host/src/sandbox/initialized_multi_use.rs +++ b/src/hyperlight_host/src/sandbox/initialized_multi_use.rs @@ -23,10 +23,12 @@ use std::path::Path; use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::{Arc, Mutex}; +use flatbuffers::FlatBufferBuilder; use hyperlight_common::flatbuffer_wrappers::function_call::{FunctionCall, FunctionCallType}; use hyperlight_common::flatbuffer_wrappers::function_types::{ ParameterValue, ReturnType, ReturnValue, }; +use hyperlight_common::flatbuffer_wrappers::util::estimate_flatbuffer_capacity; use tracing::{Span, instrument}; use super::host_funcs::FunctionRegistry; @@ -44,7 +46,7 @@ use crate::mem::memory_region::{MemoryRegion, MemoryRegionFlags}; use crate::mem::ptr::RawPtr; use crate::mem::shared_mem::HostSharedMemory; use crate::metrics::maybe_time_and_emit_guest_call; -use crate::{HyperlightError, Result, log_then_return}; +use crate::{Result, log_then_return}; /// Global counter for assigning unique IDs to sandboxes static SANDBOX_ID_COUNTER: AtomicU64 = AtomicU64::new(0); @@ -392,6 +394,8 @@ impl MultiUseSandbox { args: Vec, ) -> Result { let res = (|| { + let estimated_capacity = estimate_flatbuffer_capacity(function_name, &args); + let fc = FunctionCall::new( function_name.to_string(), Some(args), @@ -399,13 +403,12 @@ impl MultiUseSandbox { return_type, ); - let buffer: Vec = fc.try_into().map_err(|_| { - HyperlightError::Error("Failed to serialize FunctionCall".to_string()) - })?; + let mut builder = 
FlatBufferBuilder::with_capacity(estimated_capacity); + let buffer = fc.encode(&mut builder); self.get_mgr_wrapper_mut() .as_mut() - .write_guest_function_call(&buffer)?; + .write_guest_function_call(buffer)?; self.vm.dispatch_call_from_host( self.dispatch_ptr.clone(), diff --git a/src/tests/c_guests/c_simpleguest/main.c b/src/tests/c_guests/c_simpleguest/main.c index 664b8441d..b7a9a596f 100644 --- a/src/tests/c_guests/c_simpleguest/main.c +++ b/src/tests/c_guests/c_simpleguest/main.c @@ -4,6 +4,7 @@ #include "stdint.h" #include "string.h" #include "stdlib.h" +#include "assert.h" // Included from hyperlight_guest_bin/third_party/printf #include "printf.h" @@ -232,6 +233,12 @@ int log_message(const char *message, int64_t level) { return -1; } +hl_Vec *twenty_four_k_in_eight_k_out(const hl_FunctionCall* params) { + hl_Vec input = params->parameters[0].value.VecBytes; + assert(input.len == 24 * 1024); + return hl_flatbuffer_result_from_Bytes(input.data, 8 * 1024); +} + HYPERLIGHT_WRAP_FUNCTION(echo, String, 1, String) // HYPERLIGHT_WRAP_FUNCTION(set_byte_array_to_zero, 1, VecBytes) is not valid for functions that return VecBytes HYPERLIGHT_WRAP_FUNCTION(print_output, Int, 1, String) @@ -260,6 +267,7 @@ HYPERLIGHT_WRAP_FUNCTION(guest_abort_with_msg, Int, 2, Int, String) HYPERLIGHT_WRAP_FUNCTION(guest_abort_with_code, Int, 1, Int) HYPERLIGHT_WRAP_FUNCTION(execute_on_stack, Int, 0) HYPERLIGHT_WRAP_FUNCTION(log_message, Int, 2, String, Long) +// HYPERLIGHT_WRAP_FUNCTION(twenty_four_k_in_eight_k_out, VecBytes, 1, VecBytes) is not valid for functions that return VecBytes void hyperlight_main(void) { @@ -295,6 +303,9 @@ void hyperlight_main(void) HYPERLIGHT_REGISTER_FUNCTION("GuestAbortWithMessage", guest_abort_with_msg); HYPERLIGHT_REGISTER_FUNCTION("ExecuteOnStack", execute_on_stack); HYPERLIGHT_REGISTER_FUNCTION("LogMessage", log_message); + // HYPERLIGHT_REGISTER_FUNCTION macro does not work for functions that return VecBytes, + // so we use 
hl_register_function_definition directly + hl_register_function_definition("24K_in_8K_out", twenty_four_k_in_eight_k_out, 1, (hl_ParameterType[]){hl_ParameterType_VecBytes}, hl_ReturnType_VecBytes); } // This dispatch function is only used when the host dispatches a guest function diff --git a/src/tests/rust_guests/callbackguest/Cargo.lock b/src/tests/rust_guests/callbackguest/Cargo.lock index f0fdb7982..9edddc50f 100644 --- a/src/tests/rust_guests/callbackguest/Cargo.lock +++ b/src/tests/rust_guests/callbackguest/Cargo.lock @@ -85,6 +85,7 @@ name = "hyperlight-guest" version = "0.8.0" dependencies = [ "anyhow", + "flatbuffers", "hyperlight-common", "hyperlight-guest-tracing", "serde_json", diff --git a/src/tests/rust_guests/dummyguest/Cargo.lock b/src/tests/rust_guests/dummyguest/Cargo.lock index f1c9d9e55..43bc55918 100644 --- a/src/tests/rust_guests/dummyguest/Cargo.lock +++ b/src/tests/rust_guests/dummyguest/Cargo.lock @@ -83,6 +83,7 @@ name = "hyperlight-guest" version = "0.7.0" dependencies = [ "anyhow", + "flatbuffers", "hyperlight-common", "hyperlight-guest-tracing", "serde_json", diff --git a/src/tests/rust_guests/simpleguest/Cargo.lock b/src/tests/rust_guests/simpleguest/Cargo.lock index 2b323048d..ddc53b202 100644 --- a/src/tests/rust_guests/simpleguest/Cargo.lock +++ b/src/tests/rust_guests/simpleguest/Cargo.lock @@ -75,6 +75,7 @@ name = "hyperlight-guest" version = "0.8.0" dependencies = [ "anyhow", + "flatbuffers", "hyperlight-common", "hyperlight-guest-tracing", "serde_json", diff --git a/src/tests/rust_guests/simpleguest/src/main.rs b/src/tests/rust_guests/simpleguest/src/main.rs index 5b5fe4dcf..c3b17ad58 100644 --- a/src/tests/rust_guests/simpleguest/src/main.rs +++ b/src/tests/rust_guests/simpleguest/src/main.rs @@ -725,6 +725,19 @@ fn add_to_static_and_fail(_: &FunctionCall) -> Result> { )) } +#[hyperlight_guest_tracing::trace_function] +fn twenty_four_k_in_eight_k_out(function_call: &FunctionCall) -> Result> { + if let 
ParameterValue::VecBytes(input) = &function_call.parameters.as_ref().unwrap()[0] { + assert!(input.len() == 24 * 1024, "Input must be 24K bytes"); + Ok(get_flatbuffer_result(&input[..8 * 1024])) + } else { + Err(HyperlightGuestError::new( + ErrorCode::GuestFunctionParameterTypeMismatch, + "Invalid parameters passed to 24K_in_8K_out".to_string(), + )) + } +} + #[hyperlight_guest_tracing::trace_function] fn violate_seccomp_filters(function_call: &FunctionCall) -> Result> { if function_call.parameters.is_none() { @@ -901,6 +914,14 @@ fn exec_mapped_buffer(function_call: &FunctionCall) -> Result> { #[no_mangle] #[hyperlight_guest_tracing::trace_function] pub extern "C" fn hyperlight_main() { + let twenty_four_k_in_def = GuestFunctionDefinition::new( + "24K_in_8K_out".to_string(), + Vec::from(&[ParameterType::VecBytes]), + ReturnType::VecBytes, + twenty_four_k_in_eight_k_out as usize, + ); + register_function(twenty_four_k_in_def); + let read_from_user_memory_def = GuestFunctionDefinition::new( "ReadFromUserMemory".to_string(), Vec::from(&[ParameterType::ULong, ParameterType::VecBytes]), diff --git a/src/tests/rust_guests/witguest/Cargo.lock b/src/tests/rust_guests/witguest/Cargo.lock index 2ecd6119f..4f0f9e057 100644 --- a/src/tests/rust_guests/witguest/Cargo.lock +++ b/src/tests/rust_guests/witguest/Cargo.lock @@ -218,6 +218,7 @@ name = "hyperlight-guest" version = "0.7.0" dependencies = [ "anyhow", + "flatbuffers", "hyperlight-common", "hyperlight-guest-tracing", "serde_json",