Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
265 changes: 242 additions & 23 deletions crates/wassette/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tokio::fs::DirEntry;
use tokio::sync::{RwLock, Semaphore};
use tokio::sync::{Mutex, RwLock, Semaphore};
use tracing::{debug, info, instrument, warn};
use wasmtime::component::{Component, InstancePre};
use wasmtime::component::{Component, Instance, InstancePre};
use wasmtime::Store;

mod component_storage;
Expand Down Expand Up @@ -149,6 +149,22 @@
pub tool_names: Vec<String>,
}

/// Options for loading a component.
#[derive(Debug, Clone, Default)]
pub struct LoadOptions {
/// When true, the component's Store and Instance persist across tool calls,
/// enabling in-memory state and WASI resource continuity.
/// Concurrent calls to stateful components are serialized.
pub stateful: bool,
}

/// A persistent Store/Instance pair for stateful components.
/// The mutex ensures concurrent calls are serialized.
struct StatefulInstance {
store: Store<WassetteWasiState<WasiState>>,
instance: Instance,
}

impl ComponentRegistry {
fn new() -> Self {
Self::default()
Expand Down Expand Up @@ -304,6 +320,11 @@
oci_client: Arc<oci_wasm::WasmClient>,
http_client: reqwest::Client,
secrets_manager: Arc<SecretsManager>,
/// Cached Store/Instance pairs for stateful components.
/// The outer Arc<Mutex<>> serializes concurrent calls to the same component.
stateful_instances: Arc<Mutex<HashMap<String, StatefulInstance>>>,
/// Tracks which components are loaded in stateful mode.
stateful_components: Arc<RwLock<std::collections::HashSet<String>>>,
}

/// A representation of a loaded component instance. It contains both the base component info and a
Expand Down Expand Up @@ -370,6 +391,8 @@
oci_client,
http_client,
secrets_manager,
stateful_instances: Arc::new(Mutex::new(HashMap::new())),
stateful_components: Arc::new(RwLock::new(std::collections::HashSet::new())),
})
}

Expand Down Expand Up @@ -524,9 +547,32 @@
/// If a component with the given id already exists, it will be updated with the new component.
/// Returns rich [`ComponentLoadOutcome`] information describing the loaded
/// component and whether it replaced an existing instance.
///
/// This is a convenience method that calls [`load_component_with_options`] with default options.
#[instrument(skip(self))]
pub async fn load_component(&self, uri: &str) -> Result<ComponentLoadOutcome> {
debug!(uri, "Loading component");
self.load_component_with_options(uri, LoadOptions::default())
.await
}

/// Loads a new component from the given URI with explicit options.
///
/// If a component with the given id already exists, it will be updated with the new component.
/// Returns rich [`ComponentLoadOutcome`] information describing the loaded
/// component and whether it replaced an existing instance.
///
/// # Options
///
/// - `stateful`: When true, the component's Store and Instance persist across tool calls,
/// enabling in-memory state and WASI resource continuity. Concurrent calls to stateful
/// components are serialized.
#[instrument(skip(self))]
pub async fn load_component_with_options(
&self,
uri: &str,
options: LoadOptions,
) -> Result<ComponentLoadOutcome> {
debug!(uri, stateful = options.stateful, "Loading component");
let (component_id, resource) = self.resolve_component_resource(uri).await?;
let staged_path = self
.stage_component_artifact(&component_id, resource)
Expand All @@ -541,10 +587,38 @@
)
})?;

// Clear any existing stateful instance on reload (state will be re-created on first call)
if outcome.status == LoadResult::Replaced {
let was_stateful = self
.stateful_components
.write()
.await
.remove(&outcome.component_id);
if was_stateful {
self.stateful_instances
.lock()
.await
.remove(&outcome.component_id);
debug!(
component_id = %outcome.component_id,
"Cleared previous stateful instance on reload"
);
}
}

// Track stateful components
if options.stateful {
self.stateful_components
.write()
.await
.insert(outcome.component_id.clone());
}

info!(
component_id = %outcome.component_id,
status = ?outcome.status,
tools = ?outcome.tool_names,
stateful = options.stateful,
"Successfully loaded component"
);
Ok(outcome)
Expand All @@ -553,6 +627,8 @@
/// Unloads the component with the specified id. This removes the component from the runtime
/// and removes all associated files from disk, making it the reverse operation of load_component.
/// This function fails if any files cannot be removed (except when they don't exist).
///
/// For stateful components, this also drops the cached Store/Instance, losing all in-memory state.
#[instrument(skip(self))]
pub async fn unload_component(&self, id: &str) -> Result<()> {
debug!("Unloading component and removing files from disk");
Expand All @@ -574,6 +650,13 @@
self.registry.remove_component(id).await;
self.policy_manager.cleanup(id).await;

// Clean up stateful instance if present
let was_stateful = self.stateful_components.write().await.remove(id);
if was_stateful {
self.stateful_instances.lock().await.remove(id);
debug!(component_id = %id, "Stateful instance dropped");
}

info!(component_id = %id, "Component unloaded successfully");
Ok(())
}
Expand Down Expand Up @@ -963,15 +1046,37 @@
pub async fn execute_component_call(
&self,
component_id: &str,
function_name: &str,

Check warning on line 1049 in crates/wassette/src/lib.rs

View workflow job for this annotation

GitHub Actions / lint

Diff in /home/runner/work/wassette/wassette/crates/wassette/src/lib.rs
parameters: &str,
) -> Result<String> {
let is_stateful = self
.stateful_components
.read()
.await
.contains(component_id);

if is_stateful {
self.execute_stateful_component_call(component_id, function_name, parameters)
.await
} else {
self.execute_stateless_component_call(component_id, function_name, parameters)
.await
}
}

/// Executes a function call on a stateless component (fresh Store/Instance per call)
async fn execute_stateless_component_call(
&self,
component_id: &str,
function_name: &str,
parameters: &str,
) -> Result<String> {
let start_time = Instant::now();

debug!(
component_id = %component_id,
function_name = %function_name,
"Starting WebAssembly component execution"
"Starting stateless WebAssembly component execution"
);

let component = self
Expand Down Expand Up @@ -1006,6 +1111,124 @@
"Component instance created"
);

let result = self
.execute_function_on_instance(
component_id,
function_name,
parameters,
&mut store,
&instance,
)
.await?;

let total_duration = start_time.elapsed();

debug!(
component_id = %component_id,
function_name = %function_name,
total_duration_ms = %total_duration.as_millis(),
instantiation_ms = %instantiation_duration.as_millis(),
"Stateless WebAssembly component execution completed"
);

Ok(result)
}

/// Executes a function call on a stateful component (persistent Store/Instance)
/// Concurrent calls to the same stateful component are serialized via mutex.
async fn execute_stateful_component_call(
&self,
component_id: &str,
function_name: &str,
parameters: &str,
) -> Result<String> {
let start_time = Instant::now();

debug!(
component_id = %component_id,
function_name = %function_name,
"Starting stateful WebAssembly component execution"
);

// Acquire lock on stateful instances - this serializes concurrent calls
let mut stateful_instances = self.stateful_instances.lock().await;

let component = self
.get_component(component_id)
.await
.ok_or_else(|| anyhow!("Component not found: {}", component_id))?;

// Get or create the stateful instance
let is_first_call = !stateful_instances.contains_key(component_id);

Check warning on line 1162 in crates/wassette/src/lib.rs

View workflow job for this annotation

GitHub Actions / lint

Diff in /home/runner/work/wassette/wassette/crates/wassette/src/lib.rs

if is_first_call {
let (state, resource_limiter) =
self.get_wasi_state_for_component(component_id).await?;

let mut store = Store::new(self.runtime.as_ref(), state);

if resource_limiter.is_some() {
store.limiter(|state: &mut WassetteWasiState<WasiState>| {
state
.inner
.resource_limiter
.as_mut()
.expect("Resource limiter should be present - checked above")
});
}

let instantiation_start = Instant::now();
let instance = component.instance_pre.instantiate_async(&mut store).await?;
let instantiation_duration = instantiation_start.elapsed();

debug!(
component_id = %component_id,
instantiation_ms = %instantiation_duration.as_millis(),
"Stateful component instance created (first call)"
);

stateful_instances.insert(
component_id.to_string(),
StatefulInstance { store, instance },
);
}

let stateful_instance = stateful_instances
.get_mut(component_id)
.expect("StatefulInstance should exist - just created or already present");

let result = self
.execute_function_on_instance(
component_id,
function_name,
parameters,
&mut stateful_instance.store,
&stateful_instance.instance,
)
.await?;

let total_duration = start_time.elapsed();

debug!(
component_id = %component_id,
function_name = %function_name,
total_duration_ms = %total_duration.as_millis(),
first_call = is_first_call,
"Stateful WebAssembly component execution completed"
);

Ok(result)
}

/// Execute a function on an existing Store/Instance pair
async fn execute_function_on_instance(
&self,
component_id: &str,
function_name: &str,
parameters: &str,
store: &mut Store<WassetteWasiState<WasiState>>,
instance: &Instance,
) -> Result<String> {
// Use the new function identifier lookup instead of dot-splitting
let function_id = self
.registry
Expand All @@ -1020,11 +1243,11 @@

let func = if !interface_name.is_empty() {
let interface_index = instance
.get_export_index(&mut store, None, interface_name)
.get_export_index(&mut *store, None, interface_name)
.ok_or_else(|| anyhow!("Interface not found: {}", interface_name))?;

let function_index = instance
.get_export_index(&mut store, Some(&interface_index), func_name)
.get_export_index(&mut *store, Some(&interface_index), func_name)
.ok_or_else(|| {
anyhow!(
"Function not found in interface: {}.{}",
Expand All @@ -1034,7 +1257,7 @@
})?;

instance
.get_func(&mut store, function_index)
.get_func(&mut *store, function_index)
.ok_or_else(|| {
anyhow!(
"Function not found in interface: {}.{}",
Expand All @@ -1044,27 +1267,34 @@
})?
} else {
let func_index = instance
.get_export_index(&mut store, None, func_name)
.get_export_index(&mut *store, None, func_name)
.ok_or_else(|| anyhow!("Function not found: {}", func_name))?;
instance
.get_func(&mut store, func_index)
.get_func(&mut *store, func_index)
.ok_or_else(|| anyhow!("Function not found: {}", func_name))?
};

let params: serde_json::Value = serde_json::from_str(parameters)?;
let argument_vals = json_to_vals(&params, &func.params(&store))?;
let argument_vals = json_to_vals(&params, &func.params(&mut *store))?;

let mut results = create_placeholder_results(&func.results(&store));
let mut results = create_placeholder_results(&func.results(&mut *store));

let execution_start = Instant::now();

// Execute the WASM function and capture any errors
let call_result = func
.call_async(&mut store, &argument_vals, &mut results)
.call_async(&mut *store, &argument_vals, &mut results)
.await;

let execution_duration = execution_start.elapsed();

debug!(
component_id = %component_id,
function_name = %function_name,
execution_ms = %execution_duration.as_millis(),
"Function execution completed"
);

// If the call failed, check if it was due to a permission denial
if let Err(e) = call_result {
// Check if there was a permission error recorded during execution
Expand All @@ -1078,17 +1308,6 @@

let result_json = vals_to_json(&results);

let total_duration = start_time.elapsed();

debug!(
component_id = %component_id,
function_name = %function_name,
total_duration_ms = %total_duration.as_millis(),
instantiation_ms = %instantiation_duration.as_millis(),
execution_ms = %execution_duration.as_millis(),
"WebAssembly component execution completed"
);

if let Some(result_str) = result_json.as_str() {
Ok(result_str.to_string())
} else {
Expand Down
Loading
Loading