diff --git a/Cargo.lock b/Cargo.lock index aeb4d8866f3d..01413bff2d3f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -258,6 +258,15 @@ version = "2.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de" +[[package]] +name = "bitmaps" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "031043d04099746d8db04daf1fa424b2bc8bd69d92b25962dcde24da39ab64a2" +dependencies = [ + "typenum", +] + [[package]] name = "block-buffer" version = "0.10.2" @@ -611,6 +620,24 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "acbf1af155f9b9ef647e42cdc158db4b64a1b61f743629225fde6f3e0be2a7c7" +[[package]] +name = "component-async-tests" +version = "0.0.0" +dependencies = [ + "anyhow", + "flate2", + "futures", + "pretty_env_logger", + "tempfile", + "test-programs-artifacts", + "tokio", + "wasi-http-draft", + "wasm-compose", + "wasmparser", + "wasmtime", + "wasmtime-wasi", +] + [[package]] name = "component-fuzz-util" version = "0.0.0" @@ -1358,6 +1385,12 @@ dependencies = [ "winapi", ] +[[package]] +name = "fixedbitset" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80" + [[package]] name = "flagset" version = "0.4.3" @@ -1910,6 +1943,20 @@ dependencies = [ "icu_properties", ] +[[package]] +name = "im-rc" +version = "15.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af1955a75fa080c677d3972822ec4bad316169ab1cfc6c257a942c2265dbe5fe" +dependencies = [ + "bitmaps", + "rand_core", + "rand_xoshiro", + "sized-chunks", + "typenum", + "version_check", +] + [[package]] name = "indexmap" version = "1.9.1" @@ -2560,6 +2607,16 @@ version = "2.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" +[[package]] +name = "petgraph" +version = "0.6.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b4c5cc86750666a3ed20bdaf5ca2a0344f9c67674cae0515bec2da16fbaa47db" +dependencies = [ + "fixedbitset", + "indexmap 2.7.0", +] + [[package]] name = "pin-project-lite" version = "0.2.14" @@ -2743,6 +2800,15 @@ dependencies = [ "rand_core", ] +[[package]] +name = "rand_xoshiro" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6f97cdb2a36ed4183de61b2f824cc45c9f1037f28afe0a322e9fff4c108b5aaa" +dependencies = [ + "rand_core", +] + [[package]] name = "rawpointer" version = "0.2.1" @@ -3062,6 +3128,19 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_yaml" +version = "0.9.34+deprecated" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a8b1a1a2ebf674015cc02edccce75287f1a0130d394307b36743c2f5d504b47" +dependencies = [ + "indexmap 2.7.0", + "itoa", + "ryu", + "serde", + "unsafe-libyaml", +] + [[package]] name = "sha1" version = "0.10.6" @@ -3135,6 +3214,16 @@ version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "62ac7f900db32bf3fd12e0117dd3dc4da74bc52ebaac97f39668446d89694803" +[[package]] +name = "sized-chunks" +version = "0.6.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "16d69225bde7a69b235da73377861095455d298f2b970996eec25ddbb42b3d1e" +dependencies = [ + "bitmaps", + "typenum", +] + [[package]] name = "slab" version = "0.4.7" @@ -3385,15 +3474,18 @@ version = "0.0.0" dependencies = [ "anyhow", "base64 0.21.0", + "flate2", "futures", "getrandom", "libc", + "once_cell", "sha2", "url", "wasi 0.11.0+wasi-snapshot-preview1", "wasi 0.14.0+wasi-0.2.3", "wasi-nn", "wit-bindgen", + "wit-bindgen-rt 0.38.0", ] [[package]] @@ -3402,6 +3494,7 @@ version = "0.0.0" dependencies = [ "cargo_metadata", "heck 0.5.0", + "wasmparser", "wasmtime", "wit-component", ] @@ -3698,6 +3791,12 @@ version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "957e51f3646910546462e67d5f7599b9e4fb8acdd304b087a6494730f9eebf04" +[[package]] +name = "unsafe-libyaml" +version = "0.2.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "673aac59facbab8a9007c7f6108d11f63b603f7cabff99fabf650fea5c32b861" + [[package]] name = "untrusted" version = "0.9.0" @@ -3888,6 +3987,15 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "wasi-http-draft" +version = "0.1.0" +dependencies = [ + "anyhow", + "futures", + "wasmtime", +] + [[package]] name = "wasi-nn" version = "0.6.0" @@ -3964,6 +4072,27 @@ version = "0.2.97" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6ee99da9c5ba11bd675621338ef6fa52296b76b83305e9b6e5c77d4c286d6d49" +[[package]] +name = "wasm-compose" +version = "0.224.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e82eabfa1d46657d1226cf814e7cfc9423715c089d03e86d69a64ad34afd299c" +dependencies = [ + "anyhow", + "heck 0.4.1", + "im-rc", + "indexmap 2.7.0", + "log", + "petgraph", + "serde", + "serde_derive", + "serde_yaml", + "smallvec", + "wasm-encoder", + "wasmparser", + "wat", +] + [[package]] name = "wasm-encoder" version = "0.224.0" diff --git a/Cargo.toml b/Cargo.toml index 67737ef500db..daaf3eb46e8f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -86,7 +86,7 @@ rustix = { workspace = true, features = ["mm", "param", "process"] } [dev-dependencies] # depend again on wasmtime to activate its default features for tests -wasmtime = { workspace = true, features = ['default', 'winch', 'pulley', 'all-arch', 'call-hook', 'memory-protection-keys'] } +wasmtime = { workspace = true, features = ['default', 'winch', 'pulley', 'all-arch', 'call-hook', 'memory-protection-keys', 'component-model-async'] } env_logger = { workspace = true } log = { workspace = true } filecheck = { workspace = true } @@ -147,6 +147,7 @@ members = [ "crates/bench-api", "crates/c-api/artifact", "crates/environ/fuzz", + "crates/misc/component-async-tests", "crates/test-programs", "crates/wasi-preview1-component-adapter", "crates/wasi-preview1-component-adapter/verify", @@ -246,6 +247,7 @@ wasmtime-versioned-export-macros = { path = "crates/versioned-export-macros", ve wasmtime-slab = { path = "crates/slab", version = "=30.0.0" } component-test-util = { path = "crates/misc/component-test-util" } component-fuzz-util = { path = "crates/misc/component-fuzz-util" } +component-async-tests = { path = "crates/misc/component-async-tests" } wiggle = { path = "crates/wiggle", version = "=30.0.0", default-features = false } wiggle-macro = { path = "crates/wiggle/macro", version = "=30.0.0" } wiggle-generate = { path = "crates/wiggle/generate", version = "=30.0.0" } @@ -299,6 +301,7 @@ io-extras = "0.18.1" rustix = "0.38.43" # wit-bindgen: wit-bindgen = { version = "0.38.0", default-features = false } +wit-bindgen-rt = { version = "0.38.0", default-features = false } wit-bindgen-rust-macro = { version = "0.38.0", default-features = false } # wasm-tools family: @@ -312,6 +315,7 @@ wasm-mutate = "0.224.0" wit-parser = "0.224.0" wit-component = "0.224.0" wasm-wave = "0.224.0" +wasm-compose = "0.224.0" # Non-Bytecode Alliance maintained dependencies: # -------------------------- diff --git a/benches/call.rs b/benches/call.rs index 8e7d95aa8ffb..610b57452430 100644 --- a/benches/call.rs +++ b/benches/call.rs @@ -135,7 +135,7 @@ fn bench_host_to_wasm( typed_results: Results, ) where Params: WasmParams + ToVals + Copy, - Results: WasmResults + ToVals + Copy + PartialEq + Debug, + Results: WasmResults + ToVals + Copy + PartialEq + Debug + Sync + 'static, { // Benchmark the "typed" version, which should be faster than the versions // below. @@ -628,7 +628,8 @@ mod component { + PartialEq + Debug + Send - + Sync, + + Sync + + 'static, { // Benchmark the "typed" version. c.bench_function(&format!("component - host-to-wasm - typed - {name}"), |b| { diff --git a/cranelift/entity/src/primary.rs b/cranelift/entity/src/primary.rs index 89b9bdf18ae4..eeebcfa58502 100644 --- a/cranelift/entity/src/primary.rs +++ b/cranelift/entity/src/primary.rs @@ -72,6 +72,17 @@ where self.elems.get_mut(k.index()) } + /// Get the element at `k` if it exists, mutable version. + pub fn get_mut_or_insert_with(&mut self, k: K, f: impl FnOnce() -> V) -> &mut V { + if self.elems.get(k.index()).is_none() { + self.elems.insert(k.index(), f()); + } + + self.elems + .get_mut(k.index()) + .expect("missing existing element") + } + /// Is this map completely empty? pub fn is_empty(&self) -> bool { self.elems.is_empty() diff --git a/crates/cranelift/Cargo.toml b/crates/cranelift/Cargo.toml index 635a4eec91e8..f20a91139cbd 100644 --- a/crates/cranelift/Cargo.toml +++ b/crates/cranelift/Cargo.toml @@ -46,3 +46,4 @@ gc = ["wasmtime-environ/gc"] gc-drc = ["gc", "wasmtime-environ/gc-drc"] gc-null = ["gc", "wasmtime-environ/gc-null"] threads = ["wasmtime-environ/threads"] + diff --git a/crates/cranelift/src/compiler/component.rs b/crates/cranelift/src/compiler/component.rs index 4961693d958b..c6697ea920a4 100644 --- a/crates/cranelift/src/compiler/component.rs +++ b/crates/cranelift/src/compiler/component.rs @@ -94,6 +94,9 @@ impl<'a> TrampolineCompiler<'a> { Trampoline::AlwaysTrap => { self.translate_always_trap(); } + Trampoline::ResourceNew(ty) => self.translate_resource_new(*ty), + Trampoline::ResourceRep(ty) => self.translate_resource_rep(*ty), + Trampoline::ResourceDrop(ty) => self.translate_resource_drop(*ty), Trampoline::TaskBackpressure { instance } => { self.translate_task_backpressure_call(*instance) } @@ -114,92 +117,119 @@ impl<'a> TrampolineCompiler<'a> { } Trampoline::TaskYield { async_ } => self.translate_task_yield_call(*async_), Trampoline::SubtaskDrop { instance } => self.translate_subtask_drop_call(*instance), - Trampoline::StreamNew { ty } => { - _ = ty; - todo!() - } + Trampoline::StreamNew { ty } => self.translate_future_or_stream_call( + ty.as_u32(), + None, + host::stream_new, + ir::types::I64, + ), Trampoline::StreamRead { ty, options } => { - _ = (ty, options); - todo!() + if let Some(info) = self.flat_stream_element_info(*ty) { + self.translate_flat_stream_call(*ty, options, host::flat_stream_read, &info) + } else { + self.translate_future_or_stream_call( + ty.as_u32(), + Some(options), + host::stream_read, + ir::types::I64, + ) + } } Trampoline::StreamWrite { ty, options } => { - _ = (ty, options); - todo!() + if let Some(info) = self.flat_stream_element_info(*ty) { + self.translate_flat_stream_call(*ty, options, host::flat_stream_write, &info) + } else { + self.translate_future_or_stream_call( + ty.as_u32(), + Some(options), + host::stream_write, + ir::types::I64, + ) + } } Trampoline::StreamCancelRead { ty, async_ } => { - _ = (ty, async_); - todo!() + self.translate_cancel_call(ty.as_u32(), *async_, host::stream_cancel_read) } Trampoline::StreamCancelWrite { ty, async_ } => { - _ = (ty, async_); - todo!() - } - Trampoline::StreamCloseReadable { ty } => { - _ = ty; - todo!() - } - Trampoline::StreamCloseWritable { ty } => { - _ = ty; - todo!() - } - Trampoline::FutureNew { ty } => { - _ = ty; - todo!() - } - Trampoline::FutureRead { ty, options } => { - _ = (ty, options); - todo!() - } - Trampoline::FutureWrite { ty, options } => { - _ = (ty, options); - todo!() + self.translate_cancel_call(ty.as_u32(), *async_, host::stream_cancel_write) } + Trampoline::StreamCloseReadable { ty } => self.translate_future_or_stream_call( + ty.as_u32(), + None, + host::stream_close_readable, + ir::types::I8, + ), + Trampoline::StreamCloseWritable { ty } => self.translate_future_or_stream_call( + ty.as_u32(), + None, + host::stream_close_writable, + ir::types::I8, + ), + Trampoline::FutureNew { ty } => self.translate_future_or_stream_call( + ty.as_u32(), + None, + host::future_new, + ir::types::I64, + ), + Trampoline::FutureRead { ty, options } => self.translate_future_or_stream_call( + ty.as_u32(), + Some(&options), + host::future_read, + ir::types::I64, + ), + Trampoline::FutureWrite { ty, options } => self.translate_future_or_stream_call( + ty.as_u32(), + Some(options), + host::future_write, + ir::types::I64, + ), Trampoline::FutureCancelRead { ty, async_ } => { - _ = (ty, async_); - todo!() + self.translate_cancel_call(ty.as_u32(), *async_, host::future_cancel_read) } Trampoline::FutureCancelWrite { ty, async_ } => { - _ = (ty, async_); - todo!() - } - Trampoline::FutureCloseReadable { ty } => { - _ = ty; - todo!() - } - Trampoline::FutureCloseWritable { ty } => { - _ = ty; - todo!() - } - Trampoline::ErrorContextNew { ty, options } => { - _ = (ty, options); - todo!() + self.translate_cancel_call(ty.as_u32(), *async_, host::future_cancel_write) } - Trampoline::ErrorContextDebugMessage { ty, options } => { - _ = (ty, options); - todo!() - } - Trampoline::ErrorContextDrop { ty } => { - _ = ty; - todo!() - } - Trampoline::ResourceNew(ty) => self.translate_resource_new(*ty), - Trampoline::ResourceRep(ty) => self.translate_resource_rep(*ty), - Trampoline::ResourceDrop(ty) => self.translate_resource_drop(*ty), + Trampoline::FutureCloseReadable { ty } => self.translate_future_or_stream_call( + ty.as_u32(), + None, + host::future_close_readable, + ir::types::I8, + ), + Trampoline::FutureCloseWritable { ty } => self.translate_future_or_stream_call( + ty.as_u32(), + None, + host::future_close_writable, + ir::types::I8, + ), + Trampoline::ErrorContextNew { ty, options } => self.translate_error_context_call( + *ty, + options, + host::error_context_new, + ir::types::I64, + ), + Trampoline::ErrorContextDebugMessage { ty, options } => self + .translate_error_context_call( + *ty, + options, + host::error_context_debug_message, + ir::types::I8, + ), + Trampoline::ErrorContextDrop { ty } => self.translate_error_context_drop_call(*ty), Trampoline::ResourceTransferOwn => { - self.translate_resource_libcall(host::resource_transfer_own, |me, rets| { + self.translate_host_libcall(host::resource_transfer_own, |me, rets| { rets[0] = me.raise_if_negative_one(rets[0]); }) } Trampoline::ResourceTransferBorrow => { - self.translate_resource_libcall(host::resource_transfer_borrow, |me, rets| { + self.translate_host_libcall(host::resource_transfer_borrow, |me, rets| { rets[0] = me.raise_if_negative_one(rets[0]); }) } Trampoline::ResourceEnterCall => { - self.translate_resource_libcall(host::resource_enter_call, |_, _| {}) + self.translate_host_libcall(host::resource_enter_call, |_, _| {}) } Trampoline::ResourceExitCall => { - self.translate_resource_libcall(host::resource_exit_call, |me, rets| { + self.translate_host_libcall(host::resource_exit_call, |me, rets| { me.raise_if_host_trapped(rets.pop().unwrap()); }) } @@ -215,20 +245,53 @@ impl<'a> TrampolineCompiler<'a> { ir::types::I64, ), Trampoline::FutureTransfer => { - _ = host::future_transfer; - todo!() + self.translate_host_libcall(host::future_transfer, |me, rets| { + rets[0] = me.raise_if_negative_one(rets[0]); + }) } Trampoline::StreamTransfer => { - _ = host::stream_transfer; - todo!() + self.translate_host_libcall(host::stream_transfer, |me, rets| { + rets[0] = me.raise_if_negative_one(rets[0]); + }) } Trampoline::ErrorContextTransfer => { - _ = host::error_context_transfer; - todo!() + self.translate_host_libcall(host::error_context_transfer, |me, rets| { + rets[0] = me.raise_if_negative_one(rets[0]); + }) } } } + fn flat_stream_element_info(&self, ty: TypeStreamTableIndex) -> Option { + let payload = self.types[self.types[ty].ty].payload; + match payload { + None => Some(CanonicalAbiInfo { + align32: 1, + align64: 1, + flat_count: None, + size32: 0, + size64: 0, + }), + Some( + payload @ (InterfaceType::Bool + | InterfaceType::S8 + | InterfaceType::U8 + | InterfaceType::S16 + | InterfaceType::U16 + | InterfaceType::S32 + | InterfaceType::U32 + | InterfaceType::S64 + | InterfaceType::U64 + | InterfaceType::Float32 + | InterfaceType::Float64 + | InterfaceType::Char), + ) => Some(self.types.canonical_abi(&payload).clone()), + // TODO: Recursively check for other "flat" types (i.e. those without pointers or handles), + // e.g. `record`s, `variant`s, etc. which contain only flat types. + _ => None, + } + } + fn store_wasm_arguments(&mut self, args: &[Value]) -> (Value, Value) { let pointer_type = self.isa.pointer_type(); let wasm_func_ty = &self.types[self.signature].unwrap_func(); @@ -266,9 +329,9 @@ impl<'a> TrampolineCompiler<'a> { Abi::Wasm => {} Abi::Array => { - // TODO: A guest could hypothetically export the `task.return` - // intrinsic it imported, allowing the host to call it. We - // need to support that here. + // TODO: A guest could hypothetically export the same intrinsic + // it imported, allowing the host to call it directly. We need + // to support that here. self.builder.ins().trap(TRAP_INTERNAL_ASSERT); return; } @@ -469,12 +532,14 @@ impl<'a> TrampolineCompiler<'a> { instance, memory, realloc, + callback, post_return, string_encoding, - callback: _, async_, } = *options; + assert!(callback.is_none()); + // vmctx: *mut VMComponentContext host_sig.params.push(ir::AbiParam::new(pointer_type)); callee_args.push(vmctx); @@ -496,6 +561,14 @@ impl<'a> TrampolineCompiler<'a> { .iconst(ir::types::I32, i64::from(lower_ty.as_u32())), ); + // caller_instance: RuntimeComponentInstanceIndex + host_sig.params.push(ir::AbiParam::new(ir::types::I32)); + callee_args.push( + self.builder + .ins() + .iconst(ir::types::I32, i64::from(instance.as_u32())), + ); + // flags: *mut VMGlobalDefinition host_sig.params.push(ir::AbiParam::new(pointer_type)); callee_args.push( @@ -872,7 +945,7 @@ impl<'a> TrampolineCompiler<'a> { /// /// Only intended for simple trampolines and effectively acts as a bridge /// from the wasm abi to host. - fn translate_resource_libcall( + fn translate_host_libcall( &mut self, get_libcall: fn( &dyn TargetIsa, @@ -902,6 +975,189 @@ impl<'a> TrampolineCompiler<'a> { self.builder.ins().return_(&results); } + fn translate_cancel_call( + &mut self, + ty: u32, + async_: bool, + get_libcall: fn( + &dyn TargetIsa, + &mut ir::Function, + ) -> (ir::SigRef, ComponentBuiltinFunctionIndex), + ) { + let args = self.builder.func.dfg.block_params(self.block0).to_vec(); + let vmctx = args[0]; + let mut callee_args = vec![ + vmctx, + self.builder.ins().iconst(ir::types::I32, i64::from(ty)), + self.builder + .ins() + .iconst(ir::types::I8, if async_ { 1 } else { 0 }), + ]; + + callee_args.extend(args[2..].iter().copied()); + + self.translate_intrinsic_libcall(vmctx, get_libcall, &callee_args, ir::types::I64); + } + + fn translate_future_or_stream_call( + &mut self, + ty: u32, + options: Option<&CanonicalOptions>, + get_libcall: fn( + &dyn TargetIsa, + &mut ir::Function, + ) -> (ir::SigRef, ComponentBuiltinFunctionIndex), + result: ir::types::Type, + ) { + let pointer_type = self.isa.pointer_type(); + let args = self.builder.func.dfg.block_params(self.block0).to_vec(); + let vmctx = args[0]; + let mut callee_args = vec![vmctx]; + + if let Some(options) = options { + // memory: *mut VMMemoryDefinition + callee_args.push(self.builder.ins().load( + pointer_type, + MemFlags::trusted(), + vmctx, + i32::try_from(self.offsets.runtime_memory(options.memory.unwrap())).unwrap(), + )); + + // realloc: *mut VMFuncRef + callee_args.push(match options.realloc { + Some(idx) => self.builder.ins().load( + pointer_type, + MemFlags::trusted(), + vmctx, + i32::try_from(self.offsets.runtime_realloc(idx)).unwrap(), + ), + None => self.builder.ins().iconst(pointer_type, 0), + }); + + // string_encoding: StringEncoding + callee_args.push( + self.builder + .ins() + .iconst(ir::types::I8, i64::from(options.string_encoding as u8)), + ); + } + + callee_args.push(self.builder.ins().iconst(ir::types::I32, i64::from(ty))); + + callee_args.extend(args[2..].iter().copied()); + + self.translate_intrinsic_libcall(vmctx, get_libcall, &callee_args, result); + } + + fn translate_flat_stream_call( + &mut self, + ty: TypeStreamTableIndex, + options: &CanonicalOptions, + get_libcall: fn( + &dyn TargetIsa, + &mut ir::Function, + ) -> (ir::SigRef, ComponentBuiltinFunctionIndex), + info: &CanonicalAbiInfo, + ) { + let pointer_type = self.isa.pointer_type(); + let args = self.builder.func.dfg.block_params(self.block0).to_vec(); + let vmctx = args[0]; + let mut callee_args = vec![ + vmctx, + self.builder.ins().load( + pointer_type, + MemFlags::trusted(), + vmctx, + i32::try_from(self.offsets.runtime_memory(options.memory.unwrap())).unwrap(), + ), + match options.realloc { + Some(idx) => self.builder.ins().load( + pointer_type, + MemFlags::trusted(), + vmctx, + i32::try_from(self.offsets.runtime_realloc(idx)).unwrap(), + ), + None => self.builder.ins().iconst(pointer_type, 0), + }, + self.builder + .ins() + .iconst(ir::types::I32, i64::from(ty.as_u32())), + self.builder + .ins() + .iconst(ir::types::I32, i64::from(info.size32)), + self.builder + .ins() + .iconst(ir::types::I32, i64::from(info.align32)), + ]; + + callee_args.extend(args[2..].iter().copied()); + + self.translate_intrinsic_libcall(vmctx, get_libcall, &callee_args, ir::types::I64); + } + + fn translate_error_context_call( + &mut self, + ty: TypeComponentLocalErrorContextTableIndex, + options: &CanonicalOptions, + get_libcall: fn( + &dyn TargetIsa, + &mut ir::Function, + ) -> (ir::SigRef, ComponentBuiltinFunctionIndex), + result: ir::types::Type, + ) { + let pointer_type = self.isa.pointer_type(); + let args = self.builder.func.dfg.block_params(self.block0).to_vec(); + let vmctx = args[0]; + let mut callee_args = vec![ + vmctx, + self.builder.ins().load( + pointer_type, + MemFlags::trusted(), + vmctx, + i32::try_from(self.offsets.runtime_memory(options.memory.unwrap())).unwrap(), + ), + match options.realloc { + Some(idx) => self.builder.ins().load( + pointer_type, + MemFlags::trusted(), + vmctx, + i32::try_from(self.offsets.runtime_realloc(idx)).unwrap(), + ), + None => self.builder.ins().iconst(pointer_type, 0), + }, + self.builder + .ins() + .iconst(ir::types::I8, i64::from(options.string_encoding as u8)), + self.builder + .ins() + .iconst(ir::types::I32, i64::from(ty.as_u32())), + ]; + + callee_args.extend(args[2..].iter().copied()); + + self.translate_intrinsic_libcall(vmctx, get_libcall, &callee_args, result); + } + + fn translate_error_context_drop_call(&mut self, ty: TypeComponentLocalErrorContextTableIndex) { + let args = self.builder.func.dfg.block_params(self.block0).to_vec(); + let vmctx = args[0]; + let mut callee_args = vec![ + vmctx, + self.builder + .ins() + .iconst(ir::types::I32, i64::from(ty.as_u32())), + ]; + + callee_args.extend(args[2..].iter().copied()); + + self.translate_intrinsic_libcall( + vmctx, + host::error_context_drop, + &callee_args, + ir::types::I8, + ); + } + /// Loads a host function pointer for a libcall stored at the `offset` /// provided in the libcalls array. /// diff --git a/crates/environ/src/component.rs b/crates/environ/src/component.rs index 7fa794e6a067..b12d137cada6 100644 --- a/crates/environ/src/component.rs +++ b/crates/environ/src/component.rs @@ -99,9 +99,49 @@ macro_rules! foreach_builtin_component_function { async_enter(vmctx: vmctx, start: ptr_u8, return_: ptr_u8, caller_instance: u32, task_return_type: u32, params: u32, results: u32) -> bool; #[cfg(feature = "component-model-async")] async_exit(vmctx: vmctx, callback: ptr_u8, post_return: ptr_u8, caller_instance: u32, callee: ptr_u8, callee_instance: u32, param_count: u32, result_count: u32, flags: u32) -> u64; - + #[cfg(feature = "component-model-async")] + future_new(vmctx: vmctx, ty: u32) -> u64; + #[cfg(feature = "component-model-async")] + future_write(vmctx: vmctx, memory: ptr_u8, realloc: ptr_u8, string_encoding: u8, ty: u32, future: u32, address: u32) -> u64; + #[cfg(feature = "component-model-async")] + future_read(vmctx: vmctx, memory: ptr_u8, realloc: ptr_u8, string_encoding: u8, ty: u32, future: u32, address: u32) -> u64; + #[cfg(feature = "component-model-async")] + future_cancel_write(vmctx: vmctx, ty: u32, async_: u8, writer: u32) -> u64; + #[cfg(feature = "component-model-async")] + future_cancel_read(vmctx: vmctx, ty: u32, async_: u8, reader: u32) -> u64; + #[cfg(feature = "component-model-async")] + future_close_writable(vmctx: vmctx, ty: u32, writer: u32, error: u32) -> bool; + #[cfg(feature = "component-model-async")] + future_close_readable(vmctx: vmctx, ty: u32, reader: u32) -> bool; + #[cfg(feature = "component-model-async")] + stream_new(vmctx: vmctx, ty: u32) -> u64; + #[cfg(feature = "component-model-async")] + stream_write(vmctx: vmctx, memory: ptr_u8, realloc: ptr_u8, string_encoding: u8, ty: u32, stream: u32, address: u32, count: u32) -> u64; + #[cfg(feature = "component-model-async")] + stream_read(vmctx: vmctx, memory: ptr_u8, realloc: ptr_u8, string_encoding: u8, ty: u32, stream: u32, address: u32, count: u32) -> u64; + #[cfg(feature = "component-model-async")] + stream_cancel_write(vmctx: vmctx, ty: u32, async_: u8, writer: u32) -> u64; + #[cfg(feature = "component-model-async")] + stream_cancel_read(vmctx: vmctx, ty: u32, async_: u8, reader: u32) -> u64; + #[cfg(feature = "component-model-async")] + stream_close_writable(vmctx: vmctx, ty: u32, writer: u32, error: u32) -> bool; + #[cfg(feature = "component-model-async")] + stream_close_readable(vmctx: vmctx, ty: u32, reader: u32) -> bool; + #[cfg(feature = "component-model-async")] + flat_stream_write(vmctx: vmctx, memory: ptr_u8, realloc: ptr_u8, ty: u32, payload_size: u32, payload_align: u32, stream: u32, address: u32, count: u32) -> u64; + #[cfg(feature = "component-model-async")] + flat_stream_read(vmctx: vmctx, memory: ptr_u8, realloc: ptr_u8, ty: u32, payload_size: u32, payload_align: u32, stream: u32, address: u32, count: u32) -> u64; + #[cfg(feature = "component-model-async")] + error_context_new(vmctx: vmctx, memory: ptr_u8, realloc: ptr_u8, string_encoding: u8, ty: u32, debug_msg_address: u32, debug_msg_len: u32) -> u64; + #[cfg(feature = "component-model-async")] + error_context_debug_message(vmctx: vmctx, memory: ptr_u8, realloc: ptr_u8, string_encoding: u8, ty: u32, err_ctx_handle: u32, debug_msg_address: u32) -> bool; + #[cfg(feature = "component-model-async")] + error_context_drop(vmctx: vmctx, ty: u32, err_ctx_handle: u32) -> bool; + #[cfg(feature = "component-model-async")] future_transfer(vmctx: vmctx, src_idx: u32, src_table: u32, dst_table: u32) -> u64; + #[cfg(feature = "component-model-async")] stream_transfer(vmctx: vmctx, src_idx: u32, src_table: u32, dst_table: u32) -> u64; + #[cfg(feature = "component-model-async")] error_context_transfer(vmctx: vmctx, src_idx: u32, src_table: u32, dst_table: u32) -> u64; trap(vmctx: vmctx, code: u8); diff --git a/crates/environ/src/component/translate/adapt.rs b/crates/environ/src/component/translate/adapt.rs index a2d7020691ef..7033269e83fe 100644 --- a/crates/environ/src/component/translate/adapt.rs +++ b/crates/environ/src/component/translate/adapt.rs @@ -196,6 +196,8 @@ impl<'data> Translator<'_, 'data> { names.push(name); } let wasm = module.encode(); + std::fs::write("/tmp/adapter.wasm", &wasm).unwrap(); + wasmparser::Validator::new().validate_all(&wasm).unwrap(); let imports = module.imports().to_vec(); // Extend the lifetime of the owned `wasm: Vec` on the stack to diff --git a/crates/environ/src/component/types_builder.rs b/crates/environ/src/component/types_builder.rs index a1579ed6882a..38cb2a2d0010 100644 --- a/crates/environ/src/component/types_builder.rs +++ b/crates/environ/src/component/types_builder.rs @@ -429,6 +429,11 @@ impl ComponentTypesBuilder { Ok(ret) } + /// Retrieve Wasmtime's type representation of the `error-context` type. + pub fn error_context_type(&mut self) -> Result { + self.error_context_table_type() + } + pub(crate) fn valtype( &mut self, types: TypesRef<'_>, diff --git a/crates/environ/src/fact/trampoline.rs b/crates/environ/src/fact/trampoline.rs index 3fdfde7085f7..1c48143edeee 100644 --- a/crates/environ/src/fact/trampoline.rs +++ b/crates/environ/src/fact/trampoline.rs @@ -108,45 +108,46 @@ pub(super) fn compile(module: &mut Module<'_>, adapter: &AdapterData) { ); ( - Compiler::new(module, result, lower_sig.params.len() as u32), + Compiler::new( + module, + result, + lower_sig.params.len() as u32, + emit_resource_call, + ), lower_sig, lift_sig, ) } - let async_start_adapter = |module: &mut Module, param_globals| { - let sig = module.types.async_start_signature(&adapter.lift); - let ty = module.core_types.function(&sig.params, &sig.results); - let result = module.funcs.push(Function::new( - Some(format!("[async-start]{}", adapter.name)), - ty, - )); + let async_start_adapter = + |module: &mut Module, param_globals| { + let sig = module.types.async_start_signature(&adapter.lift); + let ty = module.core_types.function(&sig.params, &sig.results); + let result = module.funcs.push(Function::new( + Some(format!("[async-start]{}", adapter.name)), + ty, + )); - Compiler::new(module, result, sig.params.len() as u32).compile_async_start_adapter( - adapter, - &sig, - param_globals, - ); + Compiler::new(module, result, sig.params.len() as u32, false) + .compile_async_start_adapter(adapter, &sig, param_globals); - result - }; + result + }; - let async_return_adapter = |module: &mut Module, result_globals| { - let sig = module.types.async_return_signature(&adapter.lift); - let ty = module.core_types.function(&sig.params, &sig.results); - let result = module.funcs.push(Function::new( - Some(format!("[async-return]{}", adapter.name)), - ty, - )); + let async_return_adapter = + |module: &mut Module, result_globals| { + let sig = module.types.async_return_signature(&adapter.lift); + let ty = module.core_types.function(&sig.params, &sig.results); + let result = module.funcs.push(Function::new( + Some(format!("[async-return]{}", adapter.name)), + ty, + )); - Compiler::new(module, result, sig.params.len() as u32).compile_async_return_adapter( - adapter, - &sig, - result_globals, - ); + Compiler::new(module, result, sig.params.len() as u32, false) + .compile_async_return_adapter(adapter, &sig, result_globals); - result - }; + result + }; match (adapter.lower.options.async_, adapter.lift.options.async_) { (false, false) => { @@ -194,6 +195,14 @@ pub(super) fn compile(module: &mut Module<'_>, adapter: &AdapterData) { // Similarly, the `async-return` function may write its result to // global variables from which the adapter function can read and // return them via the stack to the caller. + // + // TODO: More than one of these calls can be made from the same + // instance concurrently when the caller instance was itself called + // via a async-without-callback-lifted export. In that case, these + // globals could be clobbered by other calls between when we write + // to them and read from them. We need to refactor this to save the + // values in host-managed, task-local storage rather than global + // variables. let lower_sig = module.types.signature(&adapter.lower, Context::Lower); let param_globals = if lower_sig.params_indirect { None @@ -403,7 +412,12 @@ struct Memory<'a> { } impl<'a, 'b> Compiler<'a, 'b> { - fn new(module: &'b mut Module<'a>, result: FunctionId, nlocals: u32) -> Self { + fn new( + module: &'b mut Module<'a>, + result: FunctionId, + nlocals: u32, + emit_resource_call: bool, + ) -> Self { Self { types: module.types, module, @@ -413,7 +427,7 @@ impl<'a, 'b> Compiler<'a, 'b> { free_locals: HashMap::new(), traps: Vec::new(), fuel: INITIAL_FUEL, - emit_resource_call: false, + emit_resource_call, } } @@ -440,7 +454,7 @@ impl<'a, 'b> Compiler<'a, 'b> { i32::try_from(adapter.lower.instance.as_u32()).unwrap(), )); self.instruction(I32Const( - i32::try_from(self.types[adapter.lift.ty].params.as_u32()).unwrap(), + i32::try_from(self.types[adapter.lift.ty].results.as_u32()).unwrap(), )); self.instruction(LocalGet(0)); self.instruction(LocalGet(1)); @@ -495,7 +509,7 @@ impl<'a, 'b> Compiler<'a, 'b> { i32::try_from(adapter.lower.instance.as_u32()).unwrap(), )); self.instruction(I32Const( - i32::try_from(self.types[adapter.lift.ty].params.as_u32()).unwrap(), + i32::try_from(self.types[adapter.lift.ty].results.as_u32()).unwrap(), )); let results_local = if let Some(globals) = param_globals { @@ -575,7 +589,7 @@ impl<'a, 'b> Compiler<'a, 'b> { i32::try_from(adapter.lower.instance.as_u32()).unwrap(), )); self.instruction(I32Const( - i32::try_from(self.types[adapter.lift.ty].params.as_u32()).unwrap(), + i32::try_from(self.types[adapter.lift.ty].results.as_u32()).unwrap(), )); self.instruction(LocalGet(0)); self.instruction(LocalGet(1)); diff --git a/crates/fuzzing/src/generators/component_types.rs b/crates/fuzzing/src/generators/component_types.rs index b184fa28e7a9..e3df7a2cceb9 100644 --- a/crates/fuzzing/src/generators/component_types.rs +++ b/crates/fuzzing/src/generators/component_types.rs @@ -108,8 +108,10 @@ pub fn arbitrary_val(ty: &component::Type, input: &mut Unstructured) -> arbitrar .collect::>()?, ), - // Resources aren't fuzzed at this time. - Type::Own(_) | Type::Borrow(_) => unreachable!(), + // Resources, futures, streams, and error contexts aren't fuzzed at this time. + Type::Own(_) | Type::Borrow(_) | Type::Future(_) | Type::Stream(_) | Type::ErrorContext => { + unreachable!() + } }) } @@ -120,8 +122,25 @@ pub fn static_api_test<'a, P, R>( declarations: &Declarations, ) -> arbitrary::Result<()> where - P: ComponentNamedList + Lift + Lower + Clone + PartialEq + Debug + Arbitrary<'a> + 'static, - R: ComponentNamedList + Lift + Lower + Clone + PartialEq + Debug + Arbitrary<'a> + 'static, + P: ComponentNamedList + + Lift + + Lower + + Clone + + PartialEq + + Debug + + Arbitrary<'a> + + Send + + 'static, + R: ComponentNamedList + + Lift + + Lower + + Clone + + PartialEq + + Debug + + Arbitrary<'a> + + Send + + Sync + + 'static, { crate::init_fuzzing(); @@ -139,7 +158,7 @@ where .root() .func_wrap( IMPORT_FUNCTION, - |cx: StoreContextMut<'_, Box>, params: P| { + |cx: StoreContextMut<'_, Box>, params: P| { log::trace!("received parameters {params:?}"); let data: &(P, R) = cx.data().downcast_ref().unwrap(); let (expected_params, result) = data; @@ -149,7 +168,7 @@ where }, ) .unwrap(); - let mut store: Store> = Store::new(&engine, Box::new(())); + let mut store: Store> = Store::new(&engine, Box::new(())); let instance = linker.instantiate(&mut store, &component).unwrap(); let func = instance .get_typed_func::(&mut store, EXPORT_FUNCTION) diff --git a/crates/misc/component-async-tests/Cargo.toml b/crates/misc/component-async-tests/Cargo.toml new file mode 100644 index 000000000000..80cfe4da8274 --- /dev/null +++ b/crates/misc/component-async-tests/Cargo.toml @@ -0,0 +1,23 @@ +[package] +name = "component-async-tests" +authors = ["The Wasmtime Project Developers"] +license = "Apache-2.0 WITH LLVM-exception" +version = "0.0.0" +edition.workspace = true +rust-version.workspace = true +publish = false + +[dev-dependencies] +anyhow = { workspace = true } +flate2 = "1.0.30" +futures = { workspace = true } +pretty_env_logger = { workspace = true } +tempfile = { workspace = true } +test-programs-artifacts = { workspace = true } +tokio = { workspace = true, features = ["fs", "process", "macros", "rt-multi-thread", "time"] } +wasi-http-draft = { path = "http" } +wasm-compose = { workspace = true } +wasmparser = { workspace = true } +wasmtime = { workspace = true, features = ["component-model-async"] } +wasmtime-wasi = { workspace = true } + diff --git a/crates/misc/component-async-tests/http/Cargo.toml b/crates/misc/component-async-tests/http/Cargo.toml new file mode 100644 index 000000000000..c7c5d23292cf --- /dev/null +++ b/crates/misc/component-async-tests/http/Cargo.toml @@ -0,0 +1,10 @@ +[package] +name = "wasi-http-draft" +version = "0.1.0" +edition = "2021" +publish = false + +[dependencies] +anyhow = { workspace = true } +futures = { workspace = true } +wasmtime = { workspace = true, features = ["component-model-async"] } diff --git a/crates/misc/component-async-tests/http/src/lib.rs b/crates/misc/component-async-tests/http/src/lib.rs new file mode 100644 index 000000000000..930fcded6c11 --- /dev/null +++ b/crates/misc/component-async-tests/http/src/lib.rs @@ -0,0 +1,565 @@ +#![deny(warnings)] + +wasmtime::component::bindgen!({ + trappable_imports: true, + path: "../wit", + interfaces: " + import wasi:http/types@0.3.0-draft; + import wasi:http/handler@0.3.0-draft; + ", + concurrent_imports: true, + async: { + only_imports: [ + "wasi:http/types@0.3.0-draft#[static]body.finish", + "wasi:http/handler@0.3.0-draft#handle", + ] + }, + with: { + "wasi:http/types/body": Body, + "wasi:http/types/request": Request, + "wasi:http/types/request-options": RequestOptions, + "wasi:http/types/response": Response, + "wasi:http/types/fields": Fields, + } +}); + +use { + anyhow::anyhow, + std::{fmt, future::Future, mem}, + wasi::http::types::{ErrorCode, HeaderError, Method, RequestOptionsError, Scheme}, + wasmtime::{ + component::{ + self, ErrorContext, FutureReader, Linker, Resource, ResourceTable, StreamReader, + }, + AsContextMut, StoreContextMut, + }, +}; + +impl fmt::Display for Scheme { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "{}", + match self { + Scheme::Http => "http", + Scheme::Https => "https", + Scheme::Other(s) => s, + } + ) + } +} + +pub trait WasiHttpView: Send + Sized { + type Data; + + fn table(&mut self) -> &mut ResourceTable; + + fn send_request( + store: StoreContextMut<'_, Self::Data>, + request: Resource, + ) -> impl Future< + Output = impl FnOnce( + StoreContextMut<'_, Self::Data>, + ) -> wasmtime::Result, ErrorCode>> + + Send + + Sync + + 'static, + > + Send + + Sync + + 'static; +} + +impl WasiHttpView for &mut T { + type Data = T::Data; + + fn table(&mut self) -> &mut ResourceTable { + (*self).table() + } + + fn send_request( + store: StoreContextMut<'_, Self::Data>, + request: Resource, + ) -> impl Future< + Output = impl FnOnce( + StoreContextMut<'_, Self::Data>, + ) -> wasmtime::Result, ErrorCode>> + + Send + + Sync + + 'static, + > + Send + + Sync + + 'static { + T::send_request(store, request) + } +} + +pub struct WasiHttpImpl(pub T); + +impl WasiHttpView for WasiHttpImpl { + type Data = T::Data; + + fn table(&mut self) -> &mut ResourceTable { + self.0.table() + } + + fn send_request( + store: StoreContextMut<'_, Self::Data>, + request: Resource, + ) -> impl Future< + Output = impl FnOnce( + StoreContextMut<'_, Self::Data>, + ) -> wasmtime::Result, ErrorCode>> + + Send + + Sync + + 'static, + > + Send + + Sync + + 'static { + T::send_request(store, request) + } +} + +pub struct Body { + pub stream: Option>, + pub trailers: Option>>, +} + +#[derive(Clone)] +pub struct Fields(pub Vec<(String, Vec)>); + +#[derive(Default, Copy, Clone)] +pub struct RequestOptions { + pub connect_timeout: Option, + pub first_byte_timeout: Option, + pub between_bytes_timeout: Option, +} + +pub struct Request { + pub method: Method, + pub scheme: Option, + pub path_with_query: Option, + pub authority: Option, + pub headers: Fields, + pub body: Body, + pub options: Option, +} + +pub struct Response { + pub status_code: u16, + pub headers: Fields, + pub body: Body, +} + +impl wasi::http::types::HostFields for WasiHttpImpl { + fn new(&mut self) -> wasmtime::Result> { + Ok(self.table().push(Fields(Vec::new()))?) + } + + fn from_list( + &mut self, + list: Vec<(String, Vec)>, + ) -> wasmtime::Result, HeaderError>> { + Ok(Ok(self.table().push(Fields(list))?)) + } + + fn get(&mut self, this: Resource, key: String) -> wasmtime::Result>> { + Ok(self + .table() + .get(&this)? + .0 + .iter() + .filter(|(k, _)| *k == key) + .map(|(_, v)| v.clone()) + .collect()) + } + + fn has(&mut self, this: Resource, key: String) -> wasmtime::Result { + Ok(self.table().get(&this)?.0.iter().any(|(k, _)| *k == key)) + } + + fn set( + &mut self, + this: Resource, + key: String, + values: Vec>, + ) -> wasmtime::Result> { + let fields = self.table().get_mut(&this)?; + fields.0.retain(|(k, _)| *k != key); + fields + .0 + .extend(values.into_iter().map(|v| (key.clone(), v))); + Ok(Ok(())) + } + + fn delete( + &mut self, + this: Resource, + key: String, + ) -> wasmtime::Result>, HeaderError>> { + let fields = self.table().get_mut(&this)?; + let (matched, unmatched) = mem::take(&mut fields.0) + .into_iter() + .partition(|(k, _)| *k == key); + fields.0 = unmatched; + Ok(Ok(matched.into_iter().map(|(_, v)| v).collect())) + } + + fn append( + &mut self, + this: Resource, + key: String, + value: Vec, + ) -> wasmtime::Result> { + self.table().get_mut(&this)?.0.push((key, value)); + Ok(Ok(())) + } + + fn entries(&mut self, this: Resource) -> wasmtime::Result)>> { + Ok(self.table().get(&this)?.0.clone()) + } + + fn clone(&mut self, this: Resource) -> wasmtime::Result> { + let entries = self.table().get(&this)?.0.clone(); + Ok(self.table().push(Fields(entries))?) + } + + fn drop(&mut self, this: Resource) -> wasmtime::Result<()> { + self.table().delete(this)?; + Ok(()) + } +} + +impl wasi::http::types::HostBody for WasiHttpImpl +where + T::Data: WasiHttpView, +{ + type BodyData = T::Data; + + fn new( + &mut self, + stream: StreamReader, + trailers: Option>>, + ) -> wasmtime::Result> { + Ok(self.table().push(Body { + stream: Some(stream), + trailers, + })?) + } + + fn stream(&mut self, this: Resource) -> wasmtime::Result, ()>> { + // TODO: This should return a child handle + let stream = self.table().get_mut(&this)?.stream.take().ok_or_else(|| { + anyhow!("todo: allow wasi:http/types#body.stream to be called multiple times") + })?; + + Ok(Ok(stream)) + } + + fn finish( + mut store: StoreContextMut<'_, Self::BodyData>, + this: Resource, + ) -> impl Future< + Output = impl FnOnce( + StoreContextMut<'_, Self::BodyData>, + ) + -> wasmtime::Result>, ErrorCode>> + + 'static, + > + Send + + Sync + + 'static { + let trailers = (|| { + let trailers = store.data_mut().table().delete(this)?.trailers; + trailers + .map(|v| v.read(store.as_context_mut()).map(|v| v.into_future())) + .transpose() + })(); + async move { + let trailers = match trailers { + Ok(Some(trailers)) => Ok(trailers.await), + Ok(None) => Ok(None), + Err(e) => Err(e), + }; + + component::for_any(move |_| Ok(Ok(trailers?))) + } + } + + fn drop(&mut self, this: Resource) -> wasmtime::Result<()> { + self.table().delete(this)?; + Ok(()) + } +} + +impl wasi::http::types::HostRequest for WasiHttpImpl { + fn new( + &mut self, + headers: Resource, + body: Resource, + options: Option>, + ) -> wasmtime::Result> { + let headers = self.table().delete(headers)?; + let body = self.table().delete(body)?; + let options = if let Some(options) = options { + Some(self.table().delete(options)?) + } else { + None + }; + + Ok(self.table().push(Request { + method: Method::Get, + scheme: None, + path_with_query: None, + authority: None, + headers, + body, + options, + })?) + } + + fn method(&mut self, this: Resource) -> wasmtime::Result { + Ok(self.table().get(&this)?.method.clone()) + } + + fn set_method( + &mut self, + this: Resource, + method: Method, + ) -> wasmtime::Result> { + self.table().get_mut(&this)?.method = method; + Ok(Ok(())) + } + + fn scheme(&mut self, this: Resource) -> wasmtime::Result> { + Ok(self.table().get(&this)?.scheme.clone()) + } + + fn set_scheme( + &mut self, + this: Resource, + scheme: Option, + ) -> wasmtime::Result> { + self.table().get_mut(&this)?.scheme = scheme; + Ok(Ok(())) + } + + fn path_with_query(&mut self, this: Resource) -> wasmtime::Result> { + Ok(self.table().get(&this)?.path_with_query.clone()) + } + + fn set_path_with_query( + &mut self, + this: Resource, + path_with_query: Option, + ) -> wasmtime::Result> { + self.table().get_mut(&this)?.path_with_query = path_with_query; + Ok(Ok(())) + } + + fn authority(&mut self, this: Resource) -> wasmtime::Result> { + Ok(self.table().get(&this)?.authority.clone()) + } + + fn set_authority( + &mut self, + this: Resource, + authority: Option, + ) -> wasmtime::Result> { + self.table().get_mut(&this)?.authority = authority; + Ok(Ok(())) + } + + fn options( + &mut self, + this: Resource, + ) -> wasmtime::Result>> { + // TODO: This should return an immutable child handle + let options = self.table().get(&this)?.options; + Ok(if let Some(options) = options { + Some(self.table().push(options)?) + } else { + None + }) + } + + fn headers(&mut self, this: Resource) -> wasmtime::Result> { + // TODO: This should return an immutable child handle + let headers = self.table().get(&this)?.headers.clone(); + Ok(self.table().push(headers)?) + } + + fn body(&mut self, _this: Resource) -> wasmtime::Result> { + Err(anyhow!("todo: implement wasi:http/types#request.body")) + } + + fn into_parts( + &mut self, + this: Resource, + ) -> wasmtime::Result<(Resource, Resource)> { + let request = self.table().delete(this)?; + let headers = self.table().push(request.headers)?; + let body = self.table().push(request.body)?; + Ok((headers, body)) + } + + fn drop(&mut self, this: Resource) -> wasmtime::Result<()> { + self.table().delete(this)?; + Ok(()) + } +} + +impl wasi::http::types::HostResponse for WasiHttpImpl { + fn new( + &mut self, + headers: Resource, + body: Resource, + ) -> wasmtime::Result> { + let headers = self.table().delete(headers)?; + let body = self.table().delete(body)?; + + Ok(self.table().push(Response { + status_code: 200, + headers, + body, + })?) + } + + fn status_code(&mut self, this: Resource) -> wasmtime::Result { + Ok(self.table().get(&this)?.status_code) + } + + fn set_status_code( + &mut self, + this: Resource, + status_code: u16, + ) -> wasmtime::Result> { + self.table().get_mut(&this)?.status_code = status_code; + Ok(Ok(())) + } + + fn headers(&mut self, this: Resource) -> wasmtime::Result> { + // TODO: This should return an immutable child handle + let headers = self.table().get(&this)?.headers.clone(); + Ok(self.table().push(headers)?) + } + + fn body(&mut self, _this: Resource) -> wasmtime::Result> { + Err(anyhow!("todo: implement wasi:http/types#response.body")) + } + + fn into_parts( + &mut self, + this: Resource, + ) -> wasmtime::Result<(Resource, Resource)> { + let response = self.table().delete(this)?; + let headers = self.table().push(response.headers)?; + let body = self.table().push(response.body)?; + Ok((headers, body)) + } + + fn drop(&mut self, this: Resource) -> wasmtime::Result<()> { + self.table().delete(this)?; + Ok(()) + } +} + +impl wasi::http::types::HostRequestOptions for WasiHttpImpl { + fn new(&mut self) -> wasmtime::Result> { + Ok(self.table().push(RequestOptions::default())?) + } + + fn connect_timeout(&mut self, this: Resource) -> wasmtime::Result> { + Ok(self.table().get(&this)?.connect_timeout) + } + + fn set_connect_timeout( + &mut self, + this: Resource, + connect_timeout: Option, + ) -> wasmtime::Result> { + self.table().get_mut(&this)?.connect_timeout = connect_timeout; + Ok(Ok(())) + } + + fn first_byte_timeout( + &mut self, + this: Resource, + ) -> wasmtime::Result> { + Ok(self.table().get(&this)?.first_byte_timeout) + } + + fn set_first_byte_timeout( + &mut self, + this: Resource, + first_byte_timeout: Option, + ) -> wasmtime::Result> { + self.table().get_mut(&this)?.first_byte_timeout = first_byte_timeout; + Ok(Ok(())) + } + + fn between_bytes_timeout( + &mut self, + this: Resource, + ) -> wasmtime::Result> { + Ok(self.table().get(&this)?.between_bytes_timeout) + } + + fn set_between_bytes_timeout( + &mut self, + this: Resource, + between_bytes_timeout: Option, + ) -> wasmtime::Result> { + self.table().get_mut(&this)?.between_bytes_timeout = between_bytes_timeout; + Ok(Ok(())) + } + + fn drop(&mut self, this: Resource) -> wasmtime::Result<()> { + self.table().delete(this)?; + Ok(()) + } +} + +impl wasi::http::types::Host for WasiHttpImpl +where + T::Data: WasiHttpView, +{ + fn http_error_code(&mut self, _error: ErrorContext) -> wasmtime::Result> { + Err(anyhow!("todo: implement wasi:http/types#http-error-code")) + } +} + +impl wasi::http::handler::Host for WasiHttpImpl { + type Data = T::Data; + + fn handle( + store: StoreContextMut<'_, Self::Data>, + request: Resource, + ) -> impl Future< + Output = impl FnOnce( + StoreContextMut<'_, Self::Data>, + ) -> wasmtime::Result, ErrorCode>> + + Send + + Sync + + 'static, + > + Send + + Sync + + 'static { + Self::send_request(store, request) + } +} + +pub fn add_to_linker + 'static>( + linker: &mut Linker, +) -> wasmtime::Result<()> +where + ::Data: WasiHttpView, +{ + wasi::http::types::add_to_linker_get_host(linker, annotate_http(|ctx| WasiHttpImpl(ctx)))?; + wasi::http::handler::add_to_linker_get_host(linker, annotate_http(|ctx| WasiHttpImpl(ctx))) +} + +fn annotate_http(val: F) -> F +where + F: Fn(&mut T) -> WasiHttpImpl<&mut T>, +{ + val +} diff --git a/crates/misc/component-async-tests/src/lib.rs b/crates/misc/component-async-tests/src/lib.rs new file mode 100644 index 000000000000..b4688fcfcc72 --- /dev/null +++ b/crates/misc/component-async-tests/src/lib.rs @@ -0,0 +1,1474 @@ +#![deny(warnings)] + +#[cfg(test)] +mod test { + use { + anyhow::{anyhow, Result}, + futures::future, + std::{ + future::Future, + ops::DerefMut, + sync::{Arc, Mutex, Once}, + task::{Poll, Waker}, + time::Duration, + }, + tokio::fs, + transmit::exports::local::local::transmit::Control, + wasi_http_draft::{ + wasi::http::types::{Body, ErrorCode, Method, Request, Response, Scheme}, + Fields, WasiHttpView, + }, + wasm_compose::composer::ComponentComposer, + wasmtime::{ + component::{ + self, Component, FutureReader, Instance, Linker, Promise, PromisesUnordered, + Resource, ResourceTable, StreamReader, StreamWriter, Val, + }, + AsContextMut, Config, Engine, Store, StoreContextMut, + }, + wasmtime_wasi::{IoView, WasiCtx, WasiCtxBuilder, WasiView}, + }; + + macro_rules! assert_test_exists { + ($name:ident) => { + #[expect(unused_imports, reason = "just here to ensure a name exists")] + use self::$name as _; + }; + } + + test_programs_artifacts::foreach_async!(assert_test_exists); + + mod round_trip { + wasmtime::component::bindgen!({ + trappable_imports: true, + path: "wit", + world: "round-trip", + concurrent_imports: true, + concurrent_exports: true, + async: true, + }); + } + + fn init_logger() { + static ONCE: Once = Once::new(); + ONCE.call_once(pretty_env_logger::init); + } + + struct Ctx { + wasi: WasiCtx, + table: ResourceTable, + wakers: Arc>>>, + continue_: bool, + } + + impl IoView for Ctx { + fn table(&mut self) -> &mut ResourceTable { + &mut self.table + } + } + + impl WasiView for Ctx { + fn ctx(&mut self) -> &mut WasiCtx { + &mut self.wasi + } + } + + impl round_trip::local::local::baz::Host for Ctx { + type Data = Ctx; + + #[allow(clippy::manual_async_fn)] + fn foo( + _: StoreContextMut<'_, Self>, + s: String, + ) -> impl Future< + Output = impl FnOnce(StoreContextMut<'_, Self>) -> wasmtime::Result + 'static, + > + Send + + 'static { + async move { + tokio::time::sleep(Duration::from_millis(10)).await; + component::for_any(move |_: StoreContextMut<'_, Self>| { + Ok(format!("{s} - entered host - exited host")) + }) + } + } + } + + impl round_trip_direct::RoundTripDirectImports for Ctx { + type Data = Ctx; + + #[allow(clippy::manual_async_fn)] + fn foo( + _: StoreContextMut<'_, Self>, + s: String, + ) -> impl Future< + Output = impl FnOnce(StoreContextMut<'_, Self>) -> wasmtime::Result + 'static, + > + Send + + 'static { + async move { + tokio::time::sleep(Duration::from_millis(10)).await; + component::for_any(move |_: StoreContextMut<'_, Self>| { + Ok(format!("{s} - entered host - exited host")) + }) + } + } + } + + pub struct MyX; + + impl borrowing_host::local::local::borrowing_types::HostX for Ctx { + fn new(&mut self) -> Result> { + Ok(IoView::table(self).push(MyX)?) + } + + fn foo(&mut self, x: Resource) -> Result<()> { + _ = IoView::table(self).get(&x)?; + Ok(()) + } + + fn drop(&mut self, x: Resource) -> Result<()> { + IoView::table(self).delete(x)?; + Ok(()) + } + } + + impl borrowing_host::local::local::borrowing_types::Host for Ctx {} + + async fn test_round_trip(component: &[u8], input: &str, expected_output: &str) -> Result<()> { + init_logger(); + + let mut config = Config::new(); + config.debug_info(true); + config.cranelift_debug_verifier(true); + config.wasm_component_model(true); + config.wasm_component_model_async(true); + config.async_support(true); + + let engine = Engine::new(&config)?; + + let make_store = || { + Store::new( + &engine, + Ctx { + wasi: WasiCtxBuilder::new().inherit_stdio().build(), + table: ResourceTable::default(), + continue_: false, + wakers: Arc::new(Mutex::new(None)), + }, + ) + }; + + let component = Component::new(&engine, component)?; + + // First, test the `wasmtime-wit-bindgen` static API: + { + let mut linker = Linker::new(&engine); + + wasmtime_wasi::add_to_linker_async(&mut linker)?; + round_trip::RoundTrip::add_to_linker(&mut linker, |ctx| ctx)?; + + let mut store = make_store(); + + let round_trip = + round_trip::RoundTrip::instantiate_async(&mut store, &component, &linker).await?; + + // Start three concurrent calls and then join them all: + let mut promises = PromisesUnordered::new(); + for _ in 0..3 { + promises.push( + round_trip + .local_local_baz() + .call_foo(&mut store, input.to_owned()) + .await?, + ); + } + + while let Some(value) = promises.next(&mut store).await? { + assert_eq!(expected_output, &value); + } + } + + // Now do it again using the dynamic API (except for WASI, where we stick with the static API): + { + let mut linker = Linker::new(&engine); + + wasmtime_wasi::add_to_linker_async(&mut linker)?; + linker + .root() + .instance("local:local/baz")? + .func_new_concurrent("foo", |_, params| async move { + tokio::time::sleep(Duration::from_millis(10)).await; + component::for_any(move |_: StoreContextMut<'_, Ctx>| { + let Some(Val::String(s)) = params.into_iter().next() else { + unreachable!() + }; + Ok(vec![Val::String(format!( + "{s} - entered host - exited host" + ))]) + }) + })?; + + let mut store = make_store(); + + let instance = linker.instantiate_async(&mut store, &component).await?; + let baz_instance = instance + .get_export(&mut store, None, "local:local/baz") + .ok_or_else(|| anyhow!("can't find `local:local/baz` in instance"))?; + let foo_function = instance + .get_export(&mut store, Some(&baz_instance), "foo") + .ok_or_else(|| anyhow!("can't find `foo` in instance"))?; + let foo_function = instance + .get_func(&mut store, foo_function) + .ok_or_else(|| anyhow!("can't find `foo` in instance"))?; + + // Start three concurrent calls and then join them all: + let mut promises = PromisesUnordered::new(); + for _ in 0..3 { + promises.push( + foo_function + .call_concurrent(&mut store, vec![Val::String(input.to_owned())]) + .await?, + ); + } + + while let Some(value) = promises.next(&mut store).await? { + let Some(Val::String(value)) = value.into_iter().next() else { + unreachable!() + }; + assert_eq!(expected_output, &value); + } + } + + Ok(()) + } + + /// Compose two components + /// + /// a is the "root" component, and b is composed into it + async fn compose(a: &[u8], b: &[u8]) -> Result> { + let dir = tempfile::tempdir()?; + + let a_file = dir.path().join("a.wasm"); + fs::write(&a_file, a).await?; + + let b_file = dir.path().join("b.wasm"); + fs::write(&b_file, b).await?; + + ComponentComposer::new( + &a_file, + &wasm_compose::config::Config { + dir: dir.path().to_owned(), + definitions: vec![b_file.to_owned()], + ..Default::default() + }, + ) + .compose() + } + + async fn test_round_trip_uncomposed(component: &[u8]) -> Result<()> { + test_round_trip( + component, + "hello, world!", + "hello, world! - entered guest - entered host - exited host - exited guest", + ) + .await + } + + async fn test_round_trip_composed(a: &[u8], b: &[u8]) -> Result<()> { + test_round_trip( + &compose(a, b).await?, + "hello, world!", + "hello, world! - entered guest - entered guest - entered host \ + - exited host - exited guest - exited guest", + ) + .await + } + + #[tokio::test] + async fn async_round_trip_stackless() -> Result<()> { + test_round_trip_uncomposed( + &fs::read(test_programs_artifacts::ASYNC_ROUND_TRIP_STACKLESS_COMPONENT).await?, + ) + .await + } + + #[tokio::test] + async fn async_round_trip_stackful() -> Result<()> { + test_round_trip_uncomposed( + &fs::read(test_programs_artifacts::ASYNC_ROUND_TRIP_STACKFUL_COMPONENT).await?, + ) + .await + } + + #[tokio::test] + async fn async_round_trip_synchronous() -> Result<()> { + test_round_trip_uncomposed( + &fs::read(test_programs_artifacts::ASYNC_ROUND_TRIP_SYNCHRONOUS_COMPONENT).await?, + ) + .await + } + + #[tokio::test] + async fn async_round_trip_wait() -> Result<()> { + test_round_trip_uncomposed( + &fs::read(test_programs_artifacts::ASYNC_ROUND_TRIP_WAIT_COMPONENT).await?, + ) + .await + } + + #[tokio::test] + async fn async_round_trip_stackless_plus_stackless() -> Result<()> { + let stackless = + &fs::read(test_programs_artifacts::ASYNC_ROUND_TRIP_STACKLESS_COMPONENT).await?; + test_round_trip_composed(stackless, stackless).await + } + + #[tokio::test] + async fn async_round_trip_synchronous_plus_stackless() -> Result<()> { + let synchronous = + &fs::read(test_programs_artifacts::ASYNC_ROUND_TRIP_SYNCHRONOUS_COMPONENT).await?; + let stackless = + &fs::read(test_programs_artifacts::ASYNC_ROUND_TRIP_STACKLESS_COMPONENT).await?; + test_round_trip_composed(synchronous, stackless).await + } + + #[tokio::test] + async fn async_round_trip_stackless_plus_synchronous() -> Result<()> { + let stackless = + &fs::read(test_programs_artifacts::ASYNC_ROUND_TRIP_STACKLESS_COMPONENT).await?; + let synchronous = + &fs::read(test_programs_artifacts::ASYNC_ROUND_TRIP_SYNCHRONOUS_COMPONENT).await?; + test_round_trip_composed(stackless, synchronous).await + } + + #[tokio::test] + async fn async_round_trip_synchronous_plus_synchronous() -> Result<()> { + let synchronous = + &fs::read(test_programs_artifacts::ASYNC_ROUND_TRIP_SYNCHRONOUS_COMPONENT).await?; + test_round_trip_composed(synchronous, synchronous).await + } + + #[tokio::test] + async fn async_round_trip_wait_plus_wait() -> Result<()> { + let wait = &fs::read(test_programs_artifacts::ASYNC_ROUND_TRIP_WAIT_COMPONENT).await?; + test_round_trip_composed(wait, wait).await + } + + #[tokio::test] + async fn async_round_trip_synchronous_plus_wait() -> Result<()> { + let synchronous = + &fs::read(test_programs_artifacts::ASYNC_ROUND_TRIP_SYNCHRONOUS_COMPONENT).await?; + let wait = &fs::read(test_programs_artifacts::ASYNC_ROUND_TRIP_WAIT_COMPONENT).await?; + test_round_trip_composed(synchronous, wait).await + } + + #[tokio::test] + async fn async_round_trip_wait_plus_synchronous() -> Result<()> { + let wait = &fs::read(test_programs_artifacts::ASYNC_ROUND_TRIP_WAIT_COMPONENT).await?; + let synchronous = + &fs::read(test_programs_artifacts::ASYNC_ROUND_TRIP_SYNCHRONOUS_COMPONENT).await?; + test_round_trip_composed(wait, synchronous).await + } + + #[tokio::test] + async fn async_round_trip_stackless_plus_wait() -> Result<()> { + let stackless = + &fs::read(test_programs_artifacts::ASYNC_ROUND_TRIP_STACKLESS_COMPONENT).await?; + let wait = &fs::read(test_programs_artifacts::ASYNC_ROUND_TRIP_WAIT_COMPONENT).await?; + test_round_trip_composed(stackless, wait).await + } + + #[tokio::test] + async fn async_round_trip_wait_plus_stackless() -> Result<()> { + let wait = &fs::read(test_programs_artifacts::ASYNC_ROUND_TRIP_WAIT_COMPONENT).await?; + let stackless = + &fs::read(test_programs_artifacts::ASYNC_ROUND_TRIP_STACKLESS_COMPONENT).await?; + test_round_trip_composed(wait, stackless).await + } + + #[tokio::test] + async fn async_round_trip_stackful_plus_stackful() -> Result<()> { + let stackful = + &fs::read(test_programs_artifacts::ASYNC_ROUND_TRIP_STACKFUL_COMPONENT).await?; + test_round_trip_composed(stackful, stackful).await + } + + #[tokio::test] + async fn async_round_trip_stackful_plus_stackless() -> Result<()> { + let stackful = + &fs::read(test_programs_artifacts::ASYNC_ROUND_TRIP_STACKFUL_COMPONENT).await?; + let stackless = + &fs::read(test_programs_artifacts::ASYNC_ROUND_TRIP_STACKLESS_COMPONENT).await?; + test_round_trip_composed(stackful, stackless).await + } + + #[tokio::test] + async fn async_round_trip_stackless_plus_stackful() -> Result<()> { + let stackless = + &fs::read(test_programs_artifacts::ASYNC_ROUND_TRIP_STACKLESS_COMPONENT).await?; + let stackful = + &fs::read(test_programs_artifacts::ASYNC_ROUND_TRIP_STACKFUL_COMPONENT).await?; + test_round_trip_composed(stackless, stackful).await + } + + #[tokio::test] + async fn async_round_trip_synchronous_plus_stackful() -> Result<()> { + let synchronous = + &fs::read(test_programs_artifacts::ASYNC_ROUND_TRIP_SYNCHRONOUS_COMPONENT).await?; + let stackful = + &fs::read(test_programs_artifacts::ASYNC_ROUND_TRIP_STACKFUL_COMPONENT).await?; + test_round_trip_composed(synchronous, stackful).await + } + + #[tokio::test] + async fn async_round_trip_stackful_plus_synchronous() -> Result<()> { + let stackful = + &fs::read(test_programs_artifacts::ASYNC_ROUND_TRIP_STACKFUL_COMPONENT).await?; + let synchronous = + &fs::read(test_programs_artifacts::ASYNC_ROUND_TRIP_SYNCHRONOUS_COMPONENT).await?; + test_round_trip_composed(stackful, synchronous).await + } + + mod round_trip_direct { + wasmtime::component::bindgen!({ + trappable_imports: true, + path: "wit", + world: "round-trip-direct", + concurrent_imports: true, + concurrent_exports: true, + async: true, + }); + } + + async fn test_round_trip_direct( + component: &[u8], + input: &str, + expected_output: &str, + ) -> Result<()> { + init_logger(); + + let mut config = Config::new(); + config.debug_info(true); + config.cranelift_debug_verifier(true); + config.wasm_component_model(true); + config.wasm_component_model_async(true); + config.async_support(true); + + let engine = Engine::new(&config)?; + + let make_store = || { + Store::new( + &engine, + Ctx { + wasi: WasiCtxBuilder::new().inherit_stdio().build(), + table: ResourceTable::default(), + continue_: false, + wakers: Arc::new(Mutex::new(None)), + }, + ) + }; + + let component = Component::new(&engine, component)?; + + // First, test the `wasmtime-wit-bindgen` static API: + { + let mut linker = Linker::new(&engine); + + wasmtime_wasi::add_to_linker_async(&mut linker)?; + round_trip_direct::RoundTripDirect::add_to_linker(&mut linker, |ctx| ctx)?; + + let mut store = make_store(); + + let round_trip = round_trip_direct::RoundTripDirect::instantiate_async( + &mut store, &component, &linker, + ) + .await?; + + // Start three concurrent calls and then join them all: + let mut promises = PromisesUnordered::new(); + for _ in 0..3 { + promises.push(round_trip.call_foo(&mut store, input.to_owned()).await?); + } + + while let Some(value) = promises.next(&mut store).await? { + assert_eq!(expected_output, &value); + } + } + + // Now do it again using the dynamic API (except for WASI, where we stick with the static API): + { + let mut linker = Linker::new(&engine); + + wasmtime_wasi::add_to_linker_async(&mut linker)?; + linker + .root() + .func_new_concurrent("foo", |_, params| async move { + tokio::time::sleep(Duration::from_millis(10)).await; + component::for_any(move |_: StoreContextMut<'_, Ctx>| { + let Some(Val::String(s)) = params.into_iter().next() else { + unreachable!() + }; + Ok(vec![Val::String(format!( + "{s} - entered host - exited host" + ))]) + }) + })?; + + let mut store = make_store(); + + let instance = linker.instantiate_async(&mut store, &component).await?; + let foo_function = instance + .get_export(&mut store, None, "foo") + .ok_or_else(|| anyhow!("can't find `foo` in instance"))?; + let foo_function = instance + .get_func(&mut store, foo_function) + .ok_or_else(|| anyhow!("can't find `foo` in instance"))?; + + // Start three concurrent calls and then join them all: + let mut promises = PromisesUnordered::new(); + for _ in 0..3 { + promises.push( + foo_function + .call_concurrent(&mut store, vec![Val::String(input.to_owned())]) + .await?, + ); + } + + while let Some(value) = promises.next(&mut store).await? { + let Some(Val::String(value)) = value.into_iter().next() else { + unreachable!() + }; + assert_eq!(expected_output, &value); + } + } + + Ok(()) + } + + async fn test_round_trip_direct_uncomposed(component: &[u8]) -> Result<()> { + test_round_trip_direct( + component, + "hello, world!", + "hello, world! - entered guest - entered host - exited host - exited guest", + ) + .await + } + + #[tokio::test] + async fn async_round_trip_direct_stackless() -> Result<()> { + let stackless = + &fs::read(test_programs_artifacts::ASYNC_ROUND_TRIP_DIRECT_STACKLESS_COMPONENT).await?; + test_round_trip_direct_uncomposed(stackless).await + } + + mod yield_host { + wasmtime::component::bindgen!({ + path: "wit", + world: "yield-host", + concurrent_imports: true, + concurrent_exports: true, + async: { + only_imports: [ + "local:local/ready#when-ready", + ] + }, + }); + } + + mod borrowing_host { + wasmtime::component::bindgen!({ + path: "wit", + world: "borrowing-host", + trappable_imports: true, + concurrent_imports: true, + concurrent_exports: true, + async: { + only_imports: [] + }, + with: { + "local:local/borrowing-types/x": super::MyX, + } + }); + } + + impl yield_host::local::local::continue_::Host for Ctx { + fn set_continue(&mut self, v: bool) { + self.continue_ = v; + } + + fn get_continue(&mut self) -> bool { + self.continue_ + } + } + + impl yield_host::local::local::ready::Host for Ctx { + type Data = Ctx; + + fn set_ready(&mut self, ready: bool) { + let mut wakers = self.wakers.lock().unwrap(); + if ready { + if let Some(wakers) = wakers.take() { + for waker in wakers { + waker.wake(); + } + } + } else if wakers.is_none() { + *wakers = Some(Vec::new()); + } + } + + fn when_ready( + store: StoreContextMut, + ) -> impl Future) + 'static> + + Send + + Sync + + 'static { + let wakers = store.data().wakers.clone(); + future::poll_fn(move |cx| { + let mut wakers = wakers.lock().unwrap(); + if let Some(wakers) = wakers.deref_mut() { + wakers.push(cx.waker().clone()); + Poll::Pending + } else { + Poll::Ready(component::for_any(|_| ())) + } + }) + } + } + + async fn test_run(component: &[u8]) -> Result<()> { + init_logger(); + + let mut config = Config::new(); + config.debug_info(true); + config.cranelift_debug_verifier(true); + config.wasm_component_model(true); + config.wasm_component_model_async(true); + config.async_support(true); + config.epoch_interruption(true); + + let engine = Engine::new(&config)?; + + let component = Component::new(&engine, component)?; + + let mut linker = Linker::new(&engine); + + wasmtime_wasi::add_to_linker_async(&mut linker)?; + yield_host::YieldHost::add_to_linker(&mut linker, |ctx| ctx)?; + + let mut store = Store::new( + &engine, + Ctx { + wasi: WasiCtxBuilder::new().inherit_stdio().build(), + table: ResourceTable::default(), + continue_: false, + wakers: Arc::new(Mutex::new(None)), + }, + ); + store.set_epoch_deadline(1); + + std::thread::spawn(move || { + std::thread::sleep(Duration::from_secs(10)); + engine.increment_epoch(); + }); + + let yield_host = + yield_host::YieldHost::instantiate_async(&mut store, &component, &linker).await?; + + // Start three concurrent calls and then join them all: + let mut promises = PromisesUnordered::new(); + for _ in 0..3 { + promises.push(yield_host.local_local_run().call_run(&mut store).await?); + } + + while let Some(()) = promises.next(&mut store).await? { + // continue + } + + Ok(()) + } + + // No-op function; we only test this by composing it in `async_yield_caller` + #[allow( + dead_code, + reason = "here only to make the `assert_test_exists` macro happy" + )] + fn async_yield_callee() {} + + #[tokio::test] + async fn async_yield_caller() -> Result<()> { + let caller = &fs::read(test_programs_artifacts::ASYNC_YIELD_CALLER_COMPONENT).await?; + let callee = &fs::read(test_programs_artifacts::ASYNC_YIELD_CALLEE_COMPONENT).await?; + test_run(&compose(caller, callee).await?).await + } + + #[tokio::test] + async fn async_poll() -> Result<()> { + test_run(&fs::read(test_programs_artifacts::ASYNC_POLL_COMPONENT).await?).await + } + + // No-op function; we only test this by composing it in `async_backpressure_caller` + #[allow( + dead_code, + reason = "here only to make the `assert_test_exists` macro happy" + )] + fn async_backpressure_callee() {} + + #[tokio::test] + async fn async_backpressure_caller() -> Result<()> { + let caller = + &fs::read(test_programs_artifacts::ASYNC_BACKPRESSURE_CALLER_COMPONENT).await?; + let callee = + &fs::read(test_programs_artifacts::ASYNC_BACKPRESSURE_CALLEE_COMPONENT).await?; + test_run(&compose(caller, callee).await?).await + } + + #[tokio::test] + async fn async_transmit_caller() -> Result<()> { + let caller = &fs::read(test_programs_artifacts::ASYNC_TRANSMIT_CALLER_COMPONENT).await?; + let callee = &fs::read(test_programs_artifacts::ASYNC_TRANSMIT_CALLEE_COMPONENT).await?; + test_run(&compose(caller, callee).await?).await + } + + // No-op function; we only test this by composing it in `async_post_return_caller` + #[allow( + dead_code, + reason = "here only to make the `assert_test_exists` macro happy" + )] + fn async_post_return_callee() {} + + #[tokio::test] + async fn async_post_return_caller() -> Result<()> { + let caller = &fs::read(test_programs_artifacts::ASYNC_POST_RETURN_CALLER_COMPONENT).await?; + let callee = &fs::read(test_programs_artifacts::ASYNC_POST_RETURN_CALLEE_COMPONENT).await?; + test_run(&compose(caller, callee).await?).await + } + + // No-op function; we only test this by composing it in `async_unit_stream_caller` + #[allow( + dead_code, + reason = "here only to make the `assert_test_exists` macro happy" + )] + fn async_unit_stream_callee() {} + + #[tokio::test] + async fn async_unit_stream_caller() -> Result<()> { + let caller = &fs::read(test_programs_artifacts::ASYNC_UNIT_STREAM_CALLER_COMPONENT).await?; + let callee = &fs::read(test_programs_artifacts::ASYNC_UNIT_STREAM_CALLEE_COMPONENT).await?; + test_run(&compose(caller, callee).await?).await + } + + async fn test_run_bool(component: &[u8], v: bool) -> Result<()> { + init_logger(); + + let mut config = Config::new(); + config.debug_info(true); + config.cranelift_debug_verifier(true); + config.wasm_component_model(true); + config.wasm_component_model_async(true); + config.async_support(true); + config.epoch_interruption(true); + + let engine = Engine::new(&config)?; + + let component = Component::new(&engine, component)?; + + let mut linker = Linker::new(&engine); + + wasmtime_wasi::add_to_linker_async(&mut linker)?; + borrowing_host::BorrowingHost::add_to_linker(&mut linker, |ctx| ctx)?; + + let mut store = Store::new( + &engine, + Ctx { + wasi: WasiCtxBuilder::new().inherit_stdio().build(), + table: ResourceTable::default(), + continue_: false, + wakers: Arc::new(Mutex::new(None)), + }, + ); + store.set_epoch_deadline(1); + + std::thread::spawn(move || { + std::thread::sleep(Duration::from_secs(10)); + engine.increment_epoch(); + }); + + let borrowing_host = + borrowing_host::BorrowingHost::instantiate_async(&mut store, &component, &linker) + .await?; + + // Start three concurrent calls and then join them all: + let mut promises = PromisesUnordered::new(); + for _ in 0..3 { + promises.push( + borrowing_host + .local_local_run_bool() + .call_run(&mut store, v) + .await?, + ); + } + + while let Some(()) = promises.next(&mut store).await? { + // continue + } + + Ok(()) + } + + #[tokio::test] + async fn async_borrowing_caller() -> Result<()> { + let caller = &fs::read(test_programs_artifacts::ASYNC_BORROWING_CALLER_COMPONENT).await?; + let callee = &fs::read(test_programs_artifacts::ASYNC_BORROWING_CALLEE_COMPONENT).await?; + test_run_bool(&compose(caller, callee).await?, false).await + } + + #[tokio::test] + async fn async_borrowing_caller_misbehave() -> Result<()> { + let caller = &fs::read(test_programs_artifacts::ASYNC_BORROWING_CALLER_COMPONENT).await?; + let callee = &fs::read(test_programs_artifacts::ASYNC_BORROWING_CALLEE_COMPONENT).await?; + let error = format!( + "{:?}", + test_run_bool(&compose(caller, callee).await?, true) + .await + .unwrap_err() + ); + assert!(error.contains("unknown handle index"), "{error}"); + Ok(()) + } + + #[tokio::test] + async fn async_borrowing_callee() -> Result<()> { + let callee = &fs::read(test_programs_artifacts::ASYNC_BORROWING_CALLEE_COMPONENT).await?; + test_run_bool(callee, false).await + } + + #[tokio::test] + async fn async_borrowing_callee_misbehave() -> Result<()> { + let callee = &fs::read(test_programs_artifacts::ASYNC_BORROWING_CALLEE_COMPONENT).await?; + let error = format!("{:?}", test_run_bool(callee, true).await.unwrap_err()); + assert!(error.contains("unknown handle index"), "{error}"); + Ok(()) + } + + mod transmit { + wasmtime::component::bindgen!({ + path: "wit", + world: "transmit-callee", + concurrent_exports: true, + async: true, + }); + } + + trait TransmitTest { + type Instance; + type Params; + type Result; + + async fn instantiate( + store: impl AsContextMut, + component: &Component, + linker: &Linker, + ) -> Result; + + async fn call( + store: impl AsContextMut, + instance: &Self::Instance, + params: Self::Params, + ) -> Result>; + + fn into_params( + control: StreamReader, + caller_stream: StreamReader, + caller_future1: FutureReader, + caller_future2: FutureReader, + ) -> Self::Params; + + fn from_result( + store: impl AsContextMut, + result: Self::Result, + ) -> Result<( + StreamReader, + FutureReader, + FutureReader, + )>; + } + + struct StaticTransmitTest; + + impl TransmitTest for StaticTransmitTest { + type Instance = transmit::TransmitCallee; + type Params = ( + StreamReader, + StreamReader, + FutureReader, + FutureReader, + ); + type Result = ( + StreamReader, + FutureReader, + FutureReader, + ); + + async fn instantiate( + store: impl AsContextMut, + component: &Component, + linker: &Linker, + ) -> Result { + transmit::TransmitCallee::instantiate_async(store, component, linker).await + } + + async fn call( + store: impl AsContextMut, + instance: &Self::Instance, + params: Self::Params, + ) -> Result> { + instance + .local_local_transmit() + .call_exchange(store, params.0, params.1, params.2, params.3) + .await + } + + fn into_params( + control: StreamReader, + caller_stream: StreamReader, + caller_future1: FutureReader, + caller_future2: FutureReader, + ) -> Self::Params { + (control, caller_stream, caller_future1, caller_future2) + } + + fn from_result( + _: impl AsContextMut, + result: Self::Result, + ) -> Result<( + StreamReader, + FutureReader, + FutureReader, + )> { + Ok(result) + } + } + + struct DynamicTransmitTest; + + impl TransmitTest for DynamicTransmitTest { + type Instance = Instance; + type Params = Vec; + type Result = Val; + + async fn instantiate( + store: impl AsContextMut, + component: &Component, + linker: &Linker, + ) -> Result { + linker.instantiate_async(store, component).await + } + + async fn call( + mut store: impl AsContextMut, + instance: &Self::Instance, + params: Self::Params, + ) -> Result> { + let transmit_instance = instance + .get_export(store.as_context_mut(), None, "local:local/transmit") + .ok_or_else(|| anyhow!("can't find `local:local/transmit` in instance"))?; + let exchange_function = instance + .get_export(store.as_context_mut(), Some(&transmit_instance), "exchange") + .ok_or_else(|| anyhow!("can't find `exchange` in instance"))?; + let exchange_function = instance + .get_func(store.as_context_mut(), exchange_function) + .ok_or_else(|| anyhow!("can't find `exchange` in instance"))?; + + Ok(exchange_function + .call_concurrent(store, params) + .await? + .map(|results| results.into_iter().next().unwrap())) + } + + fn into_params( + control: StreamReader, + caller_stream: StreamReader, + caller_future1: FutureReader, + caller_future2: FutureReader, + ) -> Self::Params { + vec![ + control.into_val(), + caller_stream.into_val(), + caller_future1.into_val(), + caller_future2.into_val(), + ] + } + + fn from_result( + mut store: impl AsContextMut, + result: Self::Result, + ) -> Result<( + StreamReader, + FutureReader, + FutureReader, + )> { + let Val::Tuple(fields) = result else { + unreachable!() + }; + let stream = StreamReader::from_val(store.as_context_mut(), &fields[0])?; + let future1 = FutureReader::from_val(store.as_context_mut(), &fields[1])?; + let future2 = FutureReader::from_val(store.as_context_mut(), &fields[2])?; + Ok((stream, future1, future2)) + } + } + + async fn test_transmit(component: &[u8]) -> Result<()> { + init_logger(); + + test_transmit_with::(component).await?; + test_transmit_with::(component).await + } + + async fn test_transmit_with(component: &[u8]) -> Result<()> { + let mut config = Config::new(); + config.debug_info(true); + config.cranelift_debug_verifier(true); + config.wasm_component_model(true); + config.wasm_component_model_async(true); + config.async_support(true); + + let engine = Engine::new(&config)?; + + let make_store = || { + Store::new( + &engine, + Ctx { + wasi: WasiCtxBuilder::new().inherit_stdio().build(), + table: ResourceTable::default(), + continue_: false, + wakers: Arc::new(Mutex::new(None)), + }, + ) + }; + + let component = Component::new(&engine, component)?; + + let mut linker = Linker::new(&engine); + + wasmtime_wasi::add_to_linker_async(&mut linker)?; + + let mut store = make_store(); + + let instance = Test::instantiate(&mut store, &component, &linker).await?; + + enum Event { + Result(Test::Result), + ControlWriteA(StreamWriter), + ControlWriteB(StreamWriter), + ControlWriteC(StreamWriter), + ControlWriteD(StreamWriter), + WriteA(StreamWriter), + WriteB, + ReadC(Option<(StreamReader, Vec)>), + ReadD(Option), + ReadNone(Option<(StreamReader, Vec)>), + } + + let (control_tx, control_rx) = component::stream(&mut store)?; + let (caller_stream_tx, caller_stream_rx) = component::stream(&mut store)?; + let (caller_future1_tx, caller_future1_rx) = component::future(&mut store)?; + let (_caller_future2_tx, caller_future2_rx) = component::future(&mut store)?; + + let mut promises = PromisesUnordered::>::new(); + let mut caller_future1_tx = Some(caller_future1_tx); + let mut callee_stream_rx = None; + let mut callee_future1_rx = None; + let mut complete = false; + + promises.push( + control_tx + .write(&mut store, vec![Control::ReadStream("a".into())])? + .map(Event::ControlWriteA), + ); + + promises.push( + caller_stream_tx + .write(&mut store, vec!["a".into()])? + .map(Event::WriteA), + ); + + promises.push( + Test::call( + &mut store, + &instance, + Test::into_params( + control_rx, + caller_stream_rx, + caller_future1_rx, + caller_future2_rx, + ), + ) + .await? + .map(Event::Result), + ); + + while let Some(event) = promises.next(&mut store).await? { + match event { + Event::Result(result) => { + let results = Test::from_result(&mut store, result)?; + callee_stream_rx = Some(results.0); + callee_future1_rx = Some(results.1); + results.2.close(&mut store)?; + } + Event::ControlWriteA(tx) => { + promises.push( + tx.write(&mut store, vec![Control::ReadFuture("b".into())])? + .map(Event::ControlWriteB), + ); + } + Event::WriteA(tx) => { + tx.close(&mut store)?; + promises.push( + caller_future1_tx + .take() + .unwrap() + .write(&mut store, "b".into())? + .map(|()| Event::WriteB), + ); + } + Event::ControlWriteB(tx) => { + promises.push( + tx.write(&mut store, vec![Control::WriteStream("c".into())])? + .map(Event::ControlWriteC), + ); + } + Event::WriteB => { + promises.push( + callee_stream_rx + .take() + .unwrap() + .read(&mut store)? + .map(Event::ReadC), + ); + } + Event::ControlWriteC(tx) => { + promises.push( + tx.write(&mut store, vec![Control::WriteFuture("d".into())])? + .map(Event::ControlWriteD), + ); + } + Event::ReadC(None) => unreachable!(), + Event::ReadC(Some((rx, values))) => { + assert_eq!("c", &values[0]); + promises.push( + callee_future1_rx + .take() + .unwrap() + .read(&mut store)? + .map(Event::ReadD), + ); + callee_stream_rx = Some(rx); + } + Event::ControlWriteD(tx) => { + tx.close(&mut store)?; + } + Event::ReadD(None) => unreachable!(), + Event::ReadD(Some(value)) => { + assert_eq!("d", &value); + promises.push( + callee_stream_rx + .take() + .unwrap() + .read(&mut store)? + .map(Event::ReadNone), + ); + } + Event::ReadNone(Some(_)) => unreachable!(), + Event::ReadNone(None) => { + complete = true; + } + } + } + + assert!(complete); + + Ok(()) + } + + #[tokio::test] + async fn async_transmit_callee() -> Result<()> { + test_transmit(&fs::read(test_programs_artifacts::ASYNC_TRANSMIT_CALLEE_COMPONENT).await?) + .await + } + + mod proxy { + wasmtime::component::bindgen!({ + path: "wit", + world: "wasi:http/proxy", + concurrent_imports: true, + concurrent_exports: true, + async: { + only_imports: [ + "wasi:http/types@0.3.0-draft#[static]body.finish", + "wasi:http/handler@0.3.0-draft#handle", + ] + }, + with: { + "wasi:http/types": wasi_http_draft::wasi::http::types, + } + }); + } + + impl WasiHttpView for Ctx { + type Data = Ctx; + + fn table(&mut self) -> &mut ResourceTable { + &mut self.table + } + + #[allow(clippy::manual_async_fn)] + fn send_request( + _store: StoreContextMut<'_, Self::Data>, + _request: Resource, + ) -> impl Future< + Output = impl FnOnce( + StoreContextMut<'_, Self::Data>, + ) + -> wasmtime::Result, ErrorCode>> + + 'static, + > + Send + + 'static { + async move { + move |_: StoreContextMut<'_, Self>| { + Err(anyhow!("no outbound request handler available")) + } + } + } + } + + async fn test_http_echo(component: &[u8], use_compression: bool) -> Result<()> { + use { + flate2::{ + write::{DeflateDecoder, DeflateEncoder}, + Compression, + }, + std::io::Write, + }; + + init_logger(); + + let mut config = Config::new(); + config.cranelift_debug_verifier(true); + config.wasm_component_model(true); + config.wasm_component_model_async(true); + config.async_support(true); + + let engine = Engine::new(&config)?; + + let component = Component::new(&engine, component)?; + + let mut linker = Linker::new(&engine); + + wasmtime_wasi::add_to_linker_async(&mut linker)?; + wasi_http_draft::add_to_linker(&mut linker)?; + + let mut store = Store::new( + &engine, + Ctx { + wasi: WasiCtxBuilder::new().inherit_stdio().build(), + table: ResourceTable::default(), + continue_: false, + wakers: Arc::new(Mutex::new(None)), + }, + ); + + let proxy = proxy::Proxy::instantiate_async(&mut store, &component, &linker).await?; + + let headers = [("foo".into(), b"bar".into())]; + + let body = b"And the mome raths outgrabe"; + + enum Event { + RequestBodyWrite(StreamWriter), + RequestTrailersWrite, + Response(Result, ErrorCode>), + ResponseBodyRead(Option<(StreamReader, Vec)>), + ResponseTrailersRead(Option>), + } + + let mut promises = PromisesUnordered::new(); + + let (request_body_tx, request_body_rx) = component::stream(&mut store)?; + + promises.push( + request_body_tx + .write( + &mut store, + if use_compression { + let mut encoder = DeflateEncoder::new(Vec::new(), Compression::fast()); + encoder.write_all(body)?; + encoder.finish()? + } else { + body.to_vec() + }, + )? + .map(Event::RequestBodyWrite), + ); + + let trailers = vec![("fizz".into(), b"buzz".into())]; + + let (request_trailers_tx, request_trailers_rx) = component::future(&mut store)?; + + let request_trailers = IoView::table(store.data_mut()).push(Fields(trailers.clone()))?; + + promises.push( + request_trailers_tx + .write(&mut store, request_trailers)? + .map(|()| Event::RequestTrailersWrite), + ); + + let request = IoView::table(store.data_mut()).push(Request { + method: Method::Post, + scheme: Some(Scheme::Http), + path_with_query: Some("/".into()), + authority: Some("localhost".into()), + headers: Fields( + headers + .iter() + .cloned() + .chain(if use_compression { + vec![ + ("content-encoding".into(), b"deflate".into()), + ("accept-encoding".into(), b"deflate".into()), + ] + } else { + Vec::new() + }) + .collect(), + ), + body: Body { + stream: Some(request_body_rx), + trailers: Some(request_trailers_rx), + }, + options: None, + })?; + + promises.push( + proxy + .wasi_http_handler() + .call_handle(&mut store, request) + .await? + .map(Event::Response), + ); + + let mut response_body = Vec::new(); + let mut response_trailers = None; + let mut received_trailers = false; + while let Some(event) = promises.next(&mut store).await? { + match event { + Event::RequestBodyWrite(tx) => tx.close(&mut store)?, + Event::RequestTrailersWrite => {} + Event::Response(response) => { + let mut response = IoView::table(store.data_mut()).delete(response?)?; + + assert!(response.status_code == 200); + + assert!(headers.iter().all(|(k0, v0)| response + .headers + .0 + .iter() + .any(|(k1, v1)| k0 == k1 && v0 == v1))); + + if use_compression { + assert!(response.headers.0.iter().any(|(k, v)| matches!( + (k.as_str(), v.as_slice()), + ("content-encoding", b"deflate") + ))); + } + + response_trailers = response.body.trailers.take(); + + promises.push( + response + .body + .stream + .take() + .unwrap() + .read(&mut store)? + .map(Event::ResponseBodyRead), + ); + } + Event::ResponseBodyRead(Some((rx, chunk))) => { + response_body.extend(chunk); + promises.push(rx.read(&mut store)?.map(Event::ResponseBodyRead)); + } + Event::ResponseBodyRead(None) => { + let response_body = if use_compression { + let mut decoder = DeflateDecoder::new(Vec::new()); + decoder.write_all(&response_body)?; + decoder.finish()? + } else { + response_body.clone() + }; + + assert_eq!(body as &[_], &response_body); + + promises.push( + response_trailers + .take() + .unwrap() + .read(&mut store)? + .map(Event::ResponseTrailersRead), + ); + } + Event::ResponseTrailersRead(Some(response_trailers)) => { + let response_trailers = + IoView::table(store.data_mut()).delete(response_trailers)?; + + assert!(trailers.iter().all(|(k0, v0)| response_trailers + .0 + .iter() + .any(|(k1, v1)| k0 == k1 && v0 == v1))); + + received_trailers = true; + } + Event::ResponseTrailersRead(None) => panic!("expected response trailers; got none"), + } + } + + assert!(received_trailers); + + Ok(()) + } + + #[tokio::test] + async fn async_http_echo() -> Result<()> { + test_http_echo( + &fs::read(test_programs_artifacts::ASYNC_HTTP_ECHO_COMPONENT).await?, + false, + ) + .await + } + + #[tokio::test] + async fn async_http_middleware() -> Result<()> { + let echo = &fs::read(test_programs_artifacts::ASYNC_HTTP_ECHO_COMPONENT).await?; + let middleware = + &fs::read(test_programs_artifacts::ASYNC_HTTP_MIDDLEWARE_COMPONENT).await?; + test_http_echo(&compose(middleware, echo).await?, true).await + } + + #[tokio::test] + async fn async_error_context() -> Result<()> { + test_run(&fs::read(test_programs_artifacts::ASYNC_ERROR_CONTEXT_COMPONENT).await?).await + } + + #[tokio::test] + async fn async_error_context_callee() -> Result<()> { + test_run(&fs::read(test_programs_artifacts::ASYNC_ERROR_CONTEXT_COMPONENT).await?).await + } + + #[tokio::test] + async fn async_error_context_caller() -> Result<()> { + let caller = + &fs::read(test_programs_artifacts::ASYNC_ERROR_CONTEXT_CALLER_COMPONENT).await?; + let callee = + &fs::read(test_programs_artifacts::ASYNC_ERROR_CONTEXT_CALLEE_COMPONENT).await?; + test_run(&compose(caller, callee).await?).await + } +} diff --git a/crates/misc/component-async-tests/wit/deps/http/handler.wit b/crates/misc/component-async-tests/wit/deps/http/handler.wit new file mode 100644 index 000000000000..bfe459f40b26 --- /dev/null +++ b/crates/misc/component-async-tests/wit/deps/http/handler.wit @@ -0,0 +1,17 @@ +// This interface defines a handler of HTTP Requests. It may be imported by +/// components which wish to send HTTP Requests and also exported by components +/// which can respond to HTTP Requests. In addition, it may be used to pass +/// a request from one component to another without any use of a network. +interface handler { + use types.{request, response, error-code}; + + /// When exported, this function may be called with either an incoming + /// request read from the network or a request synthesized or forwarded by + /// another component. + /// + /// When imported, this function may be used to either send an outgoing + /// request over the network or pass it to another component. + handle: func( + request: request, + ) -> result; +} diff --git a/crates/misc/component-async-tests/wit/deps/http/proxy.wit b/crates/misc/component-async-tests/wit/deps/http/proxy.wit new file mode 100644 index 000000000000..efb3952134a7 --- /dev/null +++ b/crates/misc/component-async-tests/wit/deps/http/proxy.wit @@ -0,0 +1,6 @@ +package wasi:http@0.3.0-draft; + +world proxy { + import handler; + export handler; +} diff --git a/crates/misc/component-async-tests/wit/deps/http/types.wit b/crates/misc/component-async-tests/wit/deps/http/types.wit new file mode 100644 index 000000000000..4c5bd4c4eef2 --- /dev/null +++ b/crates/misc/component-async-tests/wit/deps/http/types.wit @@ -0,0 +1,424 @@ +/// This interface defines all of the types and methods for implementing HTTP +/// Requests and Responses, as well as their headers, trailers, and bodies. +interface types { + type duration = u64; + + /// This type corresponds to HTTP standard Methods. + variant method { + get, + head, + post, + put, + delete, + connect, + options, + trace, + patch, + other(string) + } + + /// This type corresponds to HTTP standard Related Schemes. + variant scheme { + HTTP, + HTTPS, + other(string) + } + + /// These cases are inspired by the IANA HTTP Proxy Error Types: + /// https://www.iana.org/assignments/http-proxy-status/http-proxy-status.xhtml#table-http-proxy-error-types + variant error-code { + DNS-timeout, + DNS-error(DNS-error-payload), + destination-not-found, + destination-unavailable, + destination-IP-prohibited, + destination-IP-unroutable, + connection-refused, + connection-terminated, + connection-timeout, + connection-read-timeout, + connection-write-timeout, + connection-limit-reached, + TLS-protocol-error, + TLS-certificate-error, + TLS-alert-received(TLS-alert-received-payload), + HTTP-request-denied, + HTTP-request-length-required, + HTTP-request-body-size(option), + HTTP-request-method-invalid, + HTTP-request-URI-invalid, + HTTP-request-URI-too-long, + HTTP-request-header-section-size(option), + HTTP-request-header-size(option), + HTTP-request-trailer-section-size(option), + HTTP-request-trailer-size(field-size-payload), + HTTP-response-incomplete, + HTTP-response-header-section-size(option), + HTTP-response-header-size(field-size-payload), + HTTP-response-body-size(option), + HTTP-response-trailer-section-size(option), + HTTP-response-trailer-size(field-size-payload), + HTTP-response-transfer-coding(option), + HTTP-response-content-coding(option), + HTTP-response-timeout, + HTTP-upgrade-failed, + HTTP-protocol-error, + loop-detected, + configuration-error, + /// This is a catch-all error for anything that doesn't fit cleanly into a + /// more specific case. It also includes an optional string for an + /// unstructured description of the error. Users should not depend on the + /// string for diagnosing errors, as it's not required to be consistent + /// between implementations. + internal-error(option) + } + + /// Defines the case payload type for `DNS-error` above: + record DNS-error-payload { + rcode: option, + info-code: option + } + + /// Defines the case payload type for `TLS-alert-received` above: + record TLS-alert-received-payload { + alert-id: option, + alert-message: option + } + + /// Defines the case payload type for `HTTP-response-{header,trailer}-size` above: + record field-size-payload { + field-name: option, + field-size: option + } + + /// Attempts to extract a http-related `error-code` from the stream `error` + /// provided. + /// + /// Stream operations may fail with a stream `error` with more information + /// about the operation that failed. This `error` can be passed to this + /// function to see if there's http-related information about the error to + /// return. + /// + /// Note that this function is fallible because not all stream errors are + /// http-related errors. + http-error-code: func(err: error-context) -> option; + + /// This type enumerates the different kinds of errors that may occur when + /// setting or appending to a `fields` resource. + variant header-error { + /// This error indicates that a `field-key` or `field-value` was + /// syntactically invalid when used with an operation that sets headers in a + /// `fields`. + invalid-syntax, + + /// This error indicates that a forbidden `field-key` was used when trying + /// to set a header in a `fields`. + forbidden, + + /// This error indicates that the operation on the `fields` was not + /// permitted because the fields are immutable. + immutable, + } + + /// This type enumerates the different kinds of errors that may occur when + /// setting fields of a `request-options` resource. + variant request-options-error { + /// Indicates the specified field is not supported by this implementation. + not-supported, + + /// Indicates that the operation on the `request-options` was not permitted + /// because it is immutable. + immutable, + } + + /// Field keys are always strings. + type field-key = string; + + /// Field values should always be ASCII strings. However, in + /// reality, HTTP implementations often have to interpret malformed values, + /// so they are provided as a list of bytes. + type field-value = list; + + /// This following block defines the `fields` resource which corresponds to + /// HTTP standard Fields. Fields are a common representation used for both + /// Headers and Trailers. + /// + /// A `fields` may be mutable or immutable. A `fields` created using the + /// constructor, `from-list`, or `clone` will be mutable, but a `fields` + /// resource given by other means (including, but not limited to, + /// `request.headers`) might be be immutable. In an immutable fields, the + /// `set`, `append`, and `delete` operations will fail with + /// `header-error.immutable`. + resource fields { + + /// Construct an empty HTTP Fields. + /// + /// The resulting `fields` is mutable. + constructor(); + + /// Construct an HTTP Fields. + /// + /// The resulting `fields` is mutable. + /// + /// The list represents each key-value pair in the Fields. Keys + /// which have multiple values are represented by multiple entries in this + /// list with the same key. + /// + /// The tuple is a pair of the field key, represented as a string, and + /// Value, represented as a list of bytes. In a valid Fields, all keys + /// and values are valid UTF-8 strings. However, values are not always + /// well-formed, so they are represented as a raw list of bytes. + /// + /// An error result will be returned if any header or value was + /// syntactically invalid, or if a header was forbidden. + from-list: static func( + entries: list> + ) -> result; + + /// Get all of the values corresponding to a key. If the key is not present + /// in this `fields`, an empty list is returned. However, if the key is + /// present but empty, this is represented by a list with one or more + /// empty field-values present. + get: func(name: field-key) -> list; + + /// Returns `true` when the key is present in this `fields`. If the key is + /// syntactically invalid, `false` is returned. + has: func(name: field-key) -> bool; + + /// Set all of the values for a key. Clears any existing values for that + /// key, if they have been set. + /// + /// Fails with `header-error.immutable` if the `fields` are immutable. + set: func(name: field-key, value: list) -> result<_, header-error>; + + /// Delete all values for a key. Does nothing if no values for the key + /// exist. + /// + /// Returns any values previously corresponding to the key. + /// + /// Fails with `header-error.immutable` if the `fields` are immutable. + delete: func(name: field-key) -> result, header-error>; + + /// Append a value for a key. Does not change or delete any existing + /// values for that key. + /// + /// Fails with `header-error.immutable` if the `fields` are immutable. + append: func(name: field-key, value: field-value) -> result<_, header-error>; + + /// Retrieve the full set of keys and values in the Fields. Like the + /// constructor, the list represents each key-value pair. + /// + /// The outer list represents each key-value pair in the Fields. Keys + /// which have multiple values are represented by multiple entries in this + /// list with the same key. + entries: func() -> list>; + + /// Make a deep copy of the Fields. Equivelant in behavior to calling the + /// `fields` constructor on the return value of `entries`. The resulting + /// `fields` is mutable. + clone: func() -> fields; + } + + /// Headers is an alias for Fields. + type headers = fields; + + /// Trailers is an alias for Fields. + type trailers = fields; + + /// Represents an HTTP Request or Response's Body. + /// + /// A body has both its contents - a stream of bytes - and a (possibly empty) + /// set of trailers, indicating that the full contents of the body have been + /// received. This resource represents the contents as a `stream` and the + /// delivery of trailers as a `trailers`, and ensures that the user of this + /// interface may only be consuming either the body contents or waiting on + /// trailers at any given time. + resource body { + + /// Construct a new `body` with the specified stream and trailers. + constructor( + %stream: stream, + trailers: option> + ); + + /// Returns the contents of the body, as a stream of bytes. + /// + /// This function may be called multiple times as long as any `stream`s + /// returned by previous calls have been dropped first. + %stream: func() -> result>; + + /// Takes ownership of `body`, and returns a `trailers`. This function will + /// trap if a `stream` child is still alive. + finish: static func(this: body) -> result, error-code>; + } + + /// Represents an HTTP Request. + resource request { + + /// Construct a new `request` with a default `method` of `GET`, and + /// `none` values for `path-with-query`, `scheme`, and `authority`. + /// + /// * `headers` is the HTTP Headers for the Response. + /// * `body` is the contents of the body, as a stream of bytes. + /// * `trailers` is an optional `future` which resolves to the HTTP Trailers + /// for the Response. + /// * `options` is optional `request-options` to be used if the request is + /// sent over a network connection. + /// + /// It is possible to construct, or manipulate with the accessor functions + /// below, an `request` with an invalid combination of `scheme` + /// and `authority`, or `headers` which are not permitted to be sent. + /// It is the obligation of the `handler.handle` implementation + /// to reject invalid constructions of `request`. + constructor( + headers: headers, + body: body, + options: option + ); + + /// Get the Method for the Request. + method: func() -> method; + /// Set the Method for the Request. Fails if the string present in a + /// `method.other` argument is not a syntactically valid method. + set-method: func(method: method) -> result; + + /// Get the combination of the HTTP Path and Query for the Request. When + /// `none`, this represents an empty Path and empty Query. + path-with-query: func() -> option; + /// Set the combination of the HTTP Path and Query for the Request. When + /// `none`, this represents an empty Path and empty Query. Fails is the + /// string given is not a syntactically valid path and query uri component. + set-path-with-query: func(path-with-query: option) -> result; + + /// Get the HTTP Related Scheme for the Request. When `none`, the + /// implementation may choose an appropriate default scheme. + scheme: func() -> option; + /// Set the HTTP Related Scheme for the Request. When `none`, the + /// implementation may choose an appropriate default scheme. Fails if the + /// string given is not a syntactically valid uri scheme. + set-scheme: func(scheme: option) -> result; + + /// Get the HTTP Authority for the Request. A value of `none` may be used + /// with Related Schemes which do not require an Authority. The HTTP and + /// HTTPS schemes always require an authority. + authority: func() -> option; + /// Set the HTTP Authority for the Request. A value of `none` may be used + /// with Related Schemes which do not require an Authority. The HTTP and + /// HTTPS schemes always require an authority. Fails if the string given is + /// not a syntactically valid uri authority. + set-authority: func(authority: option) -> result; + + /// Get the `request-options` to be associated with this request + /// + /// The returned `request-options` resource is immutable: `set-*` operations + /// will fail if invoked. + /// + /// This `request-options` resource is a child: it must be dropped before + /// the parent `request` is dropped, or its ownership is transfered to + /// another component by e.g. `handler.handle`. + options: func() -> option; + + /// Get the headers associated with the Request. + /// + /// The returned `headers` resource is immutable: `set`, `append`, and + /// `delete` operations will fail with `header-error.immutable`. + /// + /// This headers resource is a child: it must be dropped before the parent + /// `request` is dropped, or its ownership is transfered to another + /// component by e.g. `handler.handle`. + headers: func() -> headers; + + /// Get the body associated with the Request. + /// + /// This body resource is a child: it must be dropped before the parent + /// `request` is dropped, or its ownership is transfered to another + /// component by e.g. `handler.handle`. + body: func() -> body; + + /// Takes ownership of the `request` and returns the `headers` and `body`. + into-parts: static func(this: request) -> tuple; + } + + /// Parameters for making an HTTP Request. Each of these parameters is + /// currently an optional timeout applicable to the transport layer of the + /// HTTP protocol. + /// + /// These timeouts are separate from any the user may use to bound an + /// asynchronous call. + resource request-options { + /// Construct a default `request-options` value. + constructor(); + + /// The timeout for the initial connect to the HTTP Server. + connect-timeout: func() -> option; + + /// Set the timeout for the initial connect to the HTTP Server. An error + /// return value indicates that this timeout is not supported or that this + /// handle is immutable. + set-connect-timeout: func(duration: option) -> result<_, request-options-error>; + + /// The timeout for receiving the first byte of the Response body. + first-byte-timeout: func() -> option; + + /// Set the timeout for receiving the first byte of the Response body. An + /// error return value indicates that this timeout is not supported or that + /// this handle is immutable. + set-first-byte-timeout: func(duration: option) -> result<_, request-options-error>; + + /// The timeout for receiving subsequent chunks of bytes in the Response + /// body stream. + between-bytes-timeout: func() -> option; + + /// Set the timeout for receiving subsequent chunks of bytes in the Response + /// body stream. An error return value indicates that this timeout is not + /// supported or that this handle is immutable. + set-between-bytes-timeout: func(duration: option) -> result<_, request-options-error>; + } + + /// This type corresponds to the HTTP standard Status Code. + type status-code = u16; + + /// Represents an HTTP Response. + resource response { + + /// Construct an `response`, with a default `status-code` of `200`. If a + /// different `status-code` is needed, it must be set via the + /// `set-status-code` method. + /// + /// * `headers` is the HTTP Headers for the Response. + /// * `body` is the contents of the body, as a stream of bytes. + /// * `trailers` is an optional `future` which resolves to the HTTP Trailers + /// for the Response. + constructor( + headers: headers, + body: body, + ); + + /// Get the HTTP Status Code for the Response. + status-code: func() -> status-code; + + /// Set the HTTP Status Code for the Response. Fails if the status-code + /// given is not a valid http status code. + set-status-code: func(status-code: status-code) -> result; + + /// Get the headers associated with the Request. + /// + /// The returned `headers` resource is immutable: `set`, `append`, and + /// `delete` operations will fail with `header-error.immutable`. + /// + /// This headers resource is a child: it must be dropped before the parent + /// `response` is dropped, or its ownership is transfered to another + /// component by e.g. `handler.handle`. + headers: func() -> headers; + + /// Get the body associated with the Response. + /// + /// This body resource is a child: it must be dropped before the parent + /// `response` is dropped, or its ownership is transfered to another + /// component by e.g. `handler.handle`. + body: func() -> body; + + /// Takes ownership of the `response` and returns the `headers` and `body`. + into-parts: static func(this: response) -> tuple; + } +} diff --git a/crates/misc/component-async-tests/wit/test.wit b/crates/misc/component-async-tests/wit/test.wit new file mode 100644 index 000000000000..fbbc722d98d3 --- /dev/null +++ b/crates/misc/component-async-tests/wit/test.wit @@ -0,0 +1,168 @@ +package local:local; + +interface baz { + foo: func(s: string) -> string; +} + +world round-trip { + import baz; + export baz; +} + +world round-trip-direct { + import foo: func(s: string) -> string; + export foo: func(s: string) -> string; +} + +interface ready { + set-ready: func(ready: bool); + when-ready: func(); +} + +interface continue { + set-continue: func(continue: bool); + get-continue: func() -> bool; +} + +interface run { + run: func(); +} + +interface backpressure { + set-backpressure: func(enabled: bool); +} + +interface transmit { + variant control { + read-stream(string), + read-future(string), + write-stream(string), + write-future(string), + } + + exchange: func(control: stream, + caller-stream: stream, + caller-future1: future, + caller-future2: future) -> tuple, future, future>; +} + +interface post-return { + foo: func(s: string) -> string; + get-post-return-value: func() -> string; +} + +interface borrowing-types { + resource x { + constructor(); + foo: func(); + } +} + +interface borrowing { + use borrowing-types.{x}; + + foo: func(x: borrow, misbehave: bool); +} + +interface run-bool { + run: func(v: bool); +} + +interface run-result { + run-fail: func() -> result<_, error-context>; + run-pass: func() -> result<_, error-context>; +} + +interface unit-stream { + run: func(count: u32) -> stream; +} + +world yield-caller { + import continue; + import ready; + import run; + export run; +} + +world yield-callee { + import continue; + export run; +} + +world yield-host { + import continue; + import ready; + export run; +} + +world poll { + import ready; + export run; +} + +world backpressure-caller { + import backpressure; + import run; + export run; +} + +world backpressure-callee { + export backpressure; + export run; +} + +world transmit-caller { + import transmit; + export run; +} + +world transmit-callee { + export transmit; +} + +world post-return-caller { + import post-return; + export run; +} + +world post-return-callee { + export post-return; +} + +world borrowing-caller { + import borrowing; + export run-bool; +} + +world borrowing-callee { + export borrowing; + export run-bool; +} + +world borrowing-host { + import borrowing-types; + export run-bool; +} + +world error-context-usage { + export run; +} + +world error-context-callee { + export run-result; + export run; +} + +world error-context-caller { + import run-result; + export run; +} + +world unit-stream-caller { + import unit-stream; + export run; +} + +world unit-stream-callee { + export unit-stream; +} diff --git a/crates/misc/component-test-util/src/lib.rs b/crates/misc/component-test-util/src/lib.rs index 2a6e72efb5e6..a04ed65bb312 100644 --- a/crates/misc/component-test-util/src/lib.rs +++ b/crates/misc/component-test-util/src/lib.rs @@ -8,15 +8,23 @@ use wasmtime::component::{ComponentNamedList, ComponentType, Func, Lift, Lower, use wasmtime::{AsContextMut, Config, Engine}; pub trait TypedFuncExt { - fn call_and_post_return(&self, store: impl AsContextMut, params: P) -> Result; + fn call_and_post_return( + &self, + store: impl AsContextMut, + params: P, + ) -> Result; } impl TypedFuncExt for TypedFunc where P: ComponentNamedList + Lower, - R: ComponentNamedList + Lift, + R: ComponentNamedList + Lift + Send + Sync + 'static, { - fn call_and_post_return(&self, mut store: impl AsContextMut, params: P) -> Result { + fn call_and_post_return( + &self, + mut store: impl AsContextMut, + params: P, + ) -> Result { let result = self.call(&mut store, params)?; self.post_return(&mut store)?; Ok(result) @@ -24,18 +32,18 @@ where } pub trait FuncExt { - fn call_and_post_return( + fn call_and_post_return( &self, - store: impl AsContextMut, + store: impl AsContextMut, params: &[Val], results: &mut [Val], ) -> Result<()>; } impl FuncExt for Func { - fn call_and_post_return( + fn call_and_post_return( &self, - mut store: impl AsContextMut, + mut store: impl AsContextMut, params: &[Val], results: &mut [Val], ) -> Result<()> { diff --git a/crates/test-programs/Cargo.toml b/crates/test-programs/Cargo.toml index ec734b766b03..dbd29a4ca445 100644 --- a/crates/test-programs/Cargo.toml +++ b/crates/test-programs/Cargo.toml @@ -15,6 +15,7 @@ anyhow = { workspace = true, features = ['std'] } wasi = "0.11.0" wasi-nn = "0.6.0" wit-bindgen = { workspace = true, features = ['default'] } +wit-bindgen-rt = { workspace = true, features = ['async'] } libc = { workspace = true } getrandom = "0.2.9" futures = { workspace = true, default-features = false, features = ['alloc'] } @@ -22,3 +23,6 @@ url = { workspace = true } sha2 = "0.10.2" base64 = "0.21.0" wasip2 = { version = "0.14.0", package = 'wasi' } +once_cell = "1.19.0" +flate2 = "1.0.28" + diff --git a/crates/test-programs/artifacts/Cargo.toml b/crates/test-programs/artifacts/Cargo.toml index 40cc3a7bdc91..33a56fbf467e 100644 --- a/crates/test-programs/artifacts/Cargo.toml +++ b/crates/test-programs/artifacts/Cargo.toml @@ -16,4 +16,5 @@ wasmtime = { workspace = true, features = ['incremental-cache', 'cranelift', 'co [build-dependencies] heck = { workspace = true } wit-component = { workspace = true } +wasmparser = { workspace = true, features = ['features'] } cargo_metadata = "0.18.1" diff --git a/crates/test-programs/artifacts/build.rs b/crates/test-programs/artifacts/build.rs index f5966a67ea5e..5c3f3300e2c8 100644 --- a/crates/test-programs/artifacts/build.rs +++ b/crates/test-programs/artifacts/build.rs @@ -4,6 +4,7 @@ use std::env; use std::fs; use std::path::{Path, PathBuf}; use std::process::Command; +use wasmparser::{Validator, WasmFeatures}; use wit_component::ComponentEncoder; fn main() { @@ -57,13 +58,13 @@ fn build_and_generate_tests() { let mut kinds = BTreeMap::new(); for target in targets { - let camel = target.to_shouty_snake_case(); + let shouty = target.to_shouty_snake_case(); let wasm = out_dir .join("wasm32-wasip1") .join("debug") .join(format!("{target}.wasm")); - generated_code += &format!("pub const {camel}: &'static str = {wasm:?};\n"); + generated_code += &format!("pub const {shouty}: &'static str = {wasm:?};\n"); // Bucket, based on the name of the test, into a "kind" which generates // a `foreach_*` macro below. @@ -78,6 +79,7 @@ fn build_and_generate_tests() { s if s.starts_with("dwarf_") => "dwarf", s if s.starts_with("config_") => "config", s if s.starts_with("keyvalue_") => "keyvalue", + s if s.starts_with("async_") => "async", // If you're reading this because you hit this panic, either add it // to a test suite above or add a new "suite". The purpose of the // categorization above is to have a static assertion that tests @@ -100,11 +102,12 @@ fn build_and_generate_tests() { } let adapter = match target.as_str() { "reactor" => &reactor_adapter, + s if s.starts_with("async_") => &reactor_adapter, s if s.starts_with("api_proxy") => &proxy_adapter, _ => &command_adapter, }; let path = compile_component(&wasm, adapter); - generated_code += &format!("pub const {camel}_COMPONENT: &'static str = {path:?};\n"); + generated_code += &format!("pub const {shouty}_COMPONENT: &'static str = {path:?};\n"); } for (kind, targets) in kinds { @@ -168,11 +171,18 @@ fn compile_component(wasm: &Path, adapter: &[u8]) -> PathBuf { let component = ComponentEncoder::default() .module(module.as_slice()) .unwrap() - .validate(true) + .validate(false) .adapter("wasi_snapshot_preview1", adapter) .unwrap() .encode() .expect("module can be translated to a component"); + + Validator::new_with_features( + WasmFeatures::WASM2 | WasmFeatures::COMPONENT_MODEL | WasmFeatures::COMPONENT_MODEL_ASYNC, + ) + .validate_all(&component) + .expect("component output should validate"); + let out_dir = wasm.parent().unwrap(); let stem = wasm.file_stem().unwrap().to_str().unwrap(); let component_path = out_dir.join(format!("{stem}.component.wasm")); diff --git a/crates/test-programs/src/bin/async_backpressure_callee.rs b/crates/test-programs/src/bin/async_backpressure_callee.rs new file mode 100644 index 000000000000..d4f031193e60 --- /dev/null +++ b/crates/test-programs/src/bin/async_backpressure_callee.rs @@ -0,0 +1,36 @@ +mod bindings { + wit_bindgen::generate!({ + path: "../misc/component-async-tests/wit", + world: "backpressure-callee", + async: { + exports: [ + "local:local/run#run" + ] + } + }); + + use super::Component; + export!(Component); +} + +use { + bindings::exports::local::local::{backpressure::Guest as Backpressure, run::Guest as Run}, + wit_bindgen_rt::async_support, +}; + +struct Component; + +impl Run for Component { + async fn run() { + // do nothing + } +} + +impl Backpressure for Component { + fn set_backpressure(enabled: bool) { + async_support::task_backpressure(enabled); + } +} + +// Unused function; required since this file is built as a `bin`: +fn main() {} diff --git a/crates/test-programs/src/bin/async_backpressure_caller.rs b/crates/test-programs/src/bin/async_backpressure_caller.rs new file mode 100644 index 000000000000..7ef6478be295 --- /dev/null +++ b/crates/test-programs/src/bin/async_backpressure_caller.rs @@ -0,0 +1,81 @@ +mod bindings { + wit_bindgen::generate!({ + path: "../misc/component-async-tests/wit", + world: "backpressure-caller", + async: { + imports: [ + "local:local/run#run" + ], + exports: [ + "local:local/run#run" + ] + } + }); + + use super::Component; + export!(Component); +} + +use { + bindings::{ + exports::local::local::run::Guest, + local::local::{backpressure, run}, + }, + futures::future, + std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, + }, +}; + +struct Component; + +impl Guest for Component { + async fn run() { + backpressure::set_backpressure(true); + + let mut a = Some(Box::pin(run::run())); + let mut b = Some(Box::pin(run::run())); + let mut c = Some(Box::pin(run::run())); + + let mut backpressure_is_set = true; + future::poll_fn(move |cx| { + let a_ready = is_ready(cx, &mut a); + let b_ready = is_ready(cx, &mut b); + let c_ready = is_ready(cx, &mut c); + + if backpressure_is_set { + assert!(!a_ready); + assert!(!b_ready); + assert!(!c_ready); + + backpressure::set_backpressure(false); + backpressure_is_set = false; + + Poll::Pending + } else if a_ready && b_ready && c_ready { + Poll::Ready(()) + } else { + Poll::Pending + } + }) + .await + } +} + +fn is_ready(cx: &mut Context, fut: &mut Option>>>) -> bool { + if let Some(v) = fut.as_mut() { + if v.as_mut().poll(cx).is_ready() { + *fut = None; + true + } else { + false + } + } else { + true + } +} + +// Unused function; required since this file is built as a `bin`: +fn main() {} diff --git a/crates/test-programs/src/bin/async_borrowing_callee.rs b/crates/test-programs/src/bin/async_borrowing_callee.rs new file mode 100644 index 000000000000..9398ed9820b4 --- /dev/null +++ b/crates/test-programs/src/bin/async_borrowing_callee.rs @@ -0,0 +1,46 @@ +mod bindings { + wit_bindgen::generate!({ + path: "../misc/component-async-tests/wit", + world: "borrowing-callee", + async: { + exports: [ + "local:local/borrowing#foo", + "local:local/run-bool#run" + ] + } + }); + + use super::Component; + export!(Component); +} + +use { + bindings::{ + exports::local::local::{borrowing::Guest as Borrowing, run_bool::Guest as RunBool}, + local::local::borrowing_types::X, + }, + wit_bindgen_rt::async_support, +}; + +struct Component; + +impl Borrowing for Component { + async fn foo(x: &X, misbehave: bool) { + let handle = x.handle(); + async_support::spawn(async move { + if misbehave { + unsafe { X::from_handle(handle) }.foo(); + } + }); + x.foo(); + } +} + +impl RunBool for Component { + async fn run(misbehave: bool) { + Self::foo(&X::new(), misbehave).await + } +} + +// Unused function; required since this file is built as a `bin`: +fn main() {} diff --git a/crates/test-programs/src/bin/async_borrowing_caller.rs b/crates/test-programs/src/bin/async_borrowing_caller.rs new file mode 100644 index 000000000000..a148e4188e04 --- /dev/null +++ b/crates/test-programs/src/bin/async_borrowing_caller.rs @@ -0,0 +1,33 @@ +mod bindings { + wit_bindgen::generate!({ + path: "../misc/component-async-tests/wit", + world: "borrowing-caller", + async: { + imports: [ + "local:local/borrowing#foo" + ], + exports: [ + "local:local/run-bool#run" + ] + } + }); + + use super::Component; + export!(Component); +} + +use bindings::{ + exports::local::local::run_bool::Guest, + local::local::{borrowing::foo, borrowing_types::X}, +}; + +struct Component; + +impl Guest for Component { + async fn run(misbehave: bool) { + foo(&X::new(), misbehave).await + } +} + +// Unused function; required since this file is built as a `bin`: +fn main() {} diff --git a/crates/test-programs/src/bin/async_error_context.rs b/crates/test-programs/src/bin/async_error_context.rs new file mode 100644 index 000000000000..5a5998b4624e --- /dev/null +++ b/crates/test-programs/src/bin/async_error_context.rs @@ -0,0 +1,29 @@ +mod bindings { + wit_bindgen::generate!({ + path: "../misc/component-async-tests/wit", + world: "error-context-usage", + async: { + exports: [ + "local:local/run#run", + ], + } + }); + + use super::Component; + export!(Component); +} +use bindings::exports::local::local::run::Guest; + +use wit_bindgen_rt::async_support::error_context_new; + +struct Component; + +impl Guest for Component { + async fn run() { + let err_ctx = error_context_new("error".into()); + _ = err_ctx.debug_message(); + } +} + +// Unused function; required since this file is built as a `bin`: +fn main() {} diff --git a/crates/test-programs/src/bin/async_error_context_callee.rs b/crates/test-programs/src/bin/async_error_context_callee.rs new file mode 100644 index 000000000000..d4aa10899353 --- /dev/null +++ b/crates/test-programs/src/bin/async_error_context_callee.rs @@ -0,0 +1,36 @@ +mod bindings { + wit_bindgen::generate!({ + path: "../misc/component-async-tests/wit", + world: "error-context-callee", + async: { + exports: [ + "local:local/run#run", + "local:local/run-result#run-pass", + "local:local/run-result#run-fail", + ], + } + }); + + use super::Component; + export!(Component); +} +use wit_bindgen_rt::async_support::{error_context_new, ErrorContext}; + +struct Component; + +impl bindings::exports::local::local::run_result::Guest for Component { + async fn run_fail() -> Result<(), ErrorContext> { + Err(error_context_new("error".into())) + } + + async fn run_pass() -> Result<(), ErrorContext> { + Ok(()) + } +} + +impl bindings::exports::local::local::run::Guest for Component { + async fn run() {} +} + +// Unused function; required since this file is built as a `bin`: +fn main() {} diff --git a/crates/test-programs/src/bin/async_error_context_caller.rs b/crates/test-programs/src/bin/async_error_context_caller.rs new file mode 100644 index 000000000000..57cf72c2d695 --- /dev/null +++ b/crates/test-programs/src/bin/async_error_context_caller.rs @@ -0,0 +1,32 @@ +mod bindings { + wit_bindgen::generate!({ + path: "../misc/component-async-tests/wit", + world: "error-context-caller", + async: { + imports: [ + "local:local/run-result#run-fail", + ], + exports: [ + "local:local/run#run", + ], + } + }); + + use super::Component; + export!(Component); +} +use bindings::exports::local::local::run::Guest; + +struct Component; + +impl Guest for Component { + async fn run() { + let Err(err_ctx) = bindings::local::local::run_result::run_fail().await else { + panic!("callee failure run should have produced an error"); + }; + _ = err_ctx.debug_message(); + } +} + +// Unused function; required since this file is built as a `bin`: +fn main() {} diff --git a/crates/test-programs/src/bin/async_http_echo.rs b/crates/test-programs/src/bin/async_http_echo.rs new file mode 100644 index 000000000000..90394a65e273 --- /dev/null +++ b/crates/test-programs/src/bin/async_http_echo.rs @@ -0,0 +1,68 @@ +mod bindings { + wit_bindgen::generate!({ + path: "../misc/component-async-tests/wit", + world: "wasi:http/proxy", + async: { + imports: [ + "wasi:http/types@0.3.0-draft#[static]body.finish", + "wasi:http/handler@0.3.0-draft#handle", + ], + exports: [ + "wasi:http/handler@0.3.0-draft#handle", + ] + } + }); + + use super::Component; + export!(Component); +} + +use { + bindings::{ + exports::wasi::http::handler::Guest as Handler, + wasi::http::types::{Body, ErrorCode, Request, Response}, + wit_future, wit_stream, + }, + futures::{SinkExt, StreamExt}, + wit_bindgen_rt::async_support, +}; + +struct Component; + +impl Handler for Component { + /// Return a response which echoes the request headers, body, and trailers. + async fn handle(request: Request) -> Result { + let (headers, body) = Request::into_parts(request); + + if false { + // This is the easy and efficient way to do it... + Ok(Response::new(headers, body)) + } else { + // ...but we do it the more difficult, less efficient way here to exercise various component model + // features (e.g. `future`s, `stream`s, and post-return asynchronous execution): + let (trailers_tx, trailers_rx) = wit_future::new(); + let (mut pipe_tx, pipe_rx) = wit_stream::new(); + + async_support::spawn(async move { + let mut body_rx = body.stream().unwrap(); + while let Some(chunk) = body_rx.next().await { + pipe_tx.send(chunk).await.unwrap(); + } + + drop(pipe_tx); + + if let Some(trailers) = Body::finish(body).await.unwrap() { + trailers_tx.write(trailers).await; + } + }); + + Ok(Response::new( + headers, + Body::new(pipe_rx, Some(trailers_rx)), + )) + } + } +} + +// Unused function; required since this file is built as a `bin`: +fn main() {} diff --git a/crates/test-programs/src/bin/async_http_middleware.rs b/crates/test-programs/src/bin/async_http_middleware.rs new file mode 100644 index 000000000000..f65de7cbd3e2 --- /dev/null +++ b/crates/test-programs/src/bin/async_http_middleware.rs @@ -0,0 +1,161 @@ +mod bindings { + wit_bindgen::generate!({ + path: "../misc/component-async-tests/wit", + world: "wasi:http/proxy", + async: { + imports: [ + "wasi:http/types@0.3.0-draft#[static]body.finish", + "wasi:http/handler@0.3.0-draft#handle", + ], + exports: [ + "wasi:http/handler@0.3.0-draft#handle", + ] + } + }); + + use super::Component; + export!(Component); +} + +use { + bindings::{ + exports::wasi::http::handler::Guest as Handler, + wasi::http::{ + handler, + types::{Body, ErrorCode, Headers, Request, Response}, + }, + wit_future, wit_stream, + }, + flate2::{ + write::{DeflateDecoder, DeflateEncoder}, + Compression, + }, + futures::{SinkExt, StreamExt}, + std::{io::Write, mem}, + wit_bindgen_rt::async_support, +}; + +struct Component; + +impl Handler for Component { + /// Forward the specified request to the imported `wasi:http/handler`, transparently decoding the request body + /// if it is `deflate`d and then encoding the response body if the client has provided an `accept-encoding: + /// deflate` header. + async fn handle(request: Request) -> Result { + // First, extract the parts of the request and check for (and remove) headers pertaining to body encodings. + let method = request.method(); + let scheme = request.scheme(); + let path_with_query = request.path_with_query(); + let authority = request.authority(); + let mut accept_deflated = false; + let mut content_deflated = false; + let (headers, body) = Request::into_parts(request); + let mut headers = headers.entries(); + headers.retain(|(k, v)| match (k.as_str(), v.as_slice()) { + ("accept-encoding", b"deflate") => { + accept_deflated = true; + false + } + ("content-encoding", b"deflate") => { + content_deflated = true; + false + } + _ => true, + }); + + let body = if content_deflated { + // Next, spawn a task to pipe and decode the original request body and trailers into a new request + // we'll create below. This will run concurrently with any code in the imported `wasi:http/handler`. + let (trailers_tx, trailers_rx) = wit_future::new(); + let (mut pipe_tx, pipe_rx) = wit_stream::new(); + + async_support::spawn(async move { + { + let mut body_rx = body.stream().unwrap(); + + let mut decoder = DeflateDecoder::new(Vec::new()); + + while let Some(chunk) = body_rx.next().await { + decoder.write_all(&chunk).unwrap(); + pipe_tx.send(mem::take(decoder.get_mut())).await.unwrap(); + } + + pipe_tx.send(decoder.finish().unwrap()).await.unwrap(); + + drop(pipe_tx); + } + + if let Some(trailers) = Body::finish(body).await.unwrap() { + trailers_tx.write(trailers).await; + } + }); + + Body::new(pipe_rx, Some(trailers_rx)) + } else { + body + }; + + // While the above task (if any) is running, synthesize a request from the parts collected above and pass + // it to the imported `wasi:http/handler`. + let my_request = Request::new(Headers::from_list(&headers).unwrap(), body, None); + my_request.set_method(&method).unwrap(); + my_request.set_scheme(scheme.as_ref()).unwrap(); + my_request + .set_path_with_query(path_with_query.as_deref()) + .unwrap(); + my_request.set_authority(authority.as_deref()).unwrap(); + + let response = handler::handle(my_request).await?; + + // Now that we have the response, extract the parts, adding an extra header if we'll be encoding the body. + let status_code = response.status_code(); + let (headers, body) = Response::into_parts(response); + let mut headers = headers.entries(); + if accept_deflated { + headers.push(("content-encoding".into(), b"deflate".into())); + } + + let body = if accept_deflated { + // Spawn another task; this one is to pipe and encode the original response body and trailers into a + // new response we'll create below. This will run concurrently with the caller's code (i.e. it won't + // necessarily complete before we return a value). + let (trailers_tx, trailers_rx) = wit_future::new(); + let (mut pipe_tx, pipe_rx) = wit_stream::new(); + + async_support::spawn(async move { + { + let mut body_rx = body.stream().unwrap(); + + let mut encoder = DeflateEncoder::new(Vec::new(), Compression::fast()); + + while let Some(chunk) = body_rx.next().await { + encoder.write_all(&chunk).unwrap(); + pipe_tx.send(mem::take(encoder.get_mut())).await.unwrap(); + } + + pipe_tx.send(encoder.finish().unwrap()).await.unwrap(); + + drop(pipe_tx); + } + + if let Some(trailers) = Body::finish(body).await.unwrap() { + trailers_tx.write(trailers).await; + } + }); + + Body::new(pipe_rx, Some(trailers_rx)) + } else { + body + }; + + // While the above tasks (if any) are running, synthesize a response from the parts collected above and + // return it. + let my_response = Response::new(Headers::from_list(&headers).unwrap(), body); + my_response.set_status_code(status_code).unwrap(); + + Ok(my_response) + } +} + +// Unused function; required since this file is built as a `bin`: +fn main() {} diff --git a/crates/test-programs/src/bin/async_poll.rs b/crates/test-programs/src/bin/async_poll.rs new file mode 100644 index 000000000000..8f7d8d75f588 --- /dev/null +++ b/crates/test-programs/src/bin/async_poll.rs @@ -0,0 +1,102 @@ +mod bindings { + wit_bindgen::generate!({ + path: "../misc/component-async-tests/wit", + world: "poll", + }); + + use super::Component; + export!(Component); +} + +use bindings::{exports::local::local::run::Guest, local::local::ready}; + +fn task_poll() -> Option<(i32, i32, i32)> { + #[cfg(not(target_arch = "wasm32"))] + { + unreachable!(); + } + + #[cfg(target_arch = "wasm32")] + { + #[link(wasm_import_module = "$root")] + unsafe extern "C" { + #[link_name = "[task-poll]"] + fn poll(_: *mut i32) -> i32; + } + let mut payload = [0i32; 3]; + if unsafe { poll(payload.as_mut_ptr()) } != 0 { + Some((payload[0], payload[1], payload[2])) + } else { + None + } + } +} + +fn async_when_ready() -> i32 { + #[cfg(not(target_arch = "wasm32"))] + { + unreachable!() + } + + #[cfg(target_arch = "wasm32")] + { + #[link(wasm_import_module = "local:local/ready")] + unsafe extern "C" { + #[link_name = "[async]when-ready"] + fn call_when_ready(_: *mut u8, _: *mut u8) -> i32; + } + unsafe { call_when_ready(std::ptr::null_mut(), std::ptr::null_mut()) } + } +} + +/// Call the `subtask.drop` canonical built-in function. +fn subtask_drop(subtask: u32) { + #[cfg(not(target_arch = "wasm32"))] + { + _ = subtask; + unreachable!(); + } + + #[cfg(target_arch = "wasm32")] + { + #[link(wasm_import_module = "$root")] + unsafe extern "C" { + #[link_name = "[subtask-drop]"] + fn subtask_drop(_: u32); + } + unsafe { + subtask_drop(subtask); + } + } +} + +struct Component; + +impl Guest for Component { + fn run() { + ready::set_ready(false); + + assert!(task_poll().is_none()); + + async_when_ready(); + + assert!(task_poll().is_none()); + + ready::set_ready(true); + + let Some((3, task, _)) = task_poll() else { + panic!() + }; + + subtask_drop(task as u32); + + assert!(task_poll().is_none()); + + assert!(async_when_ready() == 3 << 30); // STATUS_DONE + + assert!(task_poll().is_none()); + } +} + +// Unused function; required since this file is built as a `bin`: +fn main() {} diff --git a/crates/test-programs/src/bin/async_post_return_callee.rs b/crates/test-programs/src/bin/async_post_return_callee.rs new file mode 100644 index 000000000000..43f58d54f03e --- /dev/null +++ b/crates/test-programs/src/bin/async_post_return_callee.rs @@ -0,0 +1,78 @@ +// Here we avoid using wit-bindgen so that we can export our own post-return +// function and keep track of whether it was called. + +use std::{ + alloc::{self, Layout}, + mem::ManuallyDrop, + sync::Mutex, +}; + +static POST_RETURN_VALUE: Mutex> = Mutex::new(None); + +#[unsafe(export_name = "local:local/post-return#foo")] +unsafe extern "C" fn export_foo(ptr: *mut u8, len: usize) -> *mut u8 { + let result = alloc::alloc(Layout::from_size_align(8, 4).unwrap()); + *result.cast::<*mut u8>() = ptr; + *result.add(4).cast::() = len; + result +} + +#[unsafe(export_name = "cabi_post_local:local/post-return#foo")] +unsafe extern "C" fn export_post_return_foo(ptr: *mut u8) { + let s_ptr = *ptr.cast::<*mut u8>(); + let s_len = *ptr.add(4).cast::(); + alloc::dealloc(ptr, Layout::from_size_align(8, 4).unwrap()); + + *POST_RETURN_VALUE.lock().unwrap() = + Some(String::from_utf8(Vec::from_raw_parts(s_ptr, s_len, s_len)).unwrap()); +} + +#[unsafe(export_name = "local:local/post-return#get-post-return-value")] +unsafe extern "C" fn export_get_post_return_value() -> *mut u8 { + let s = ManuallyDrop::new(POST_RETURN_VALUE.lock().unwrap().take().unwrap()); + let result = alloc::alloc(Layout::from_size_align(8, 4).unwrap()); + *result.cast::<*mut u8>() = s.as_ptr().cast_mut(); + *result.add(4).cast::() = s.len(); + result +} + +#[unsafe(export_name = "cabi_post_local:local/post-return#get-post-return-value")] +unsafe extern "C" fn export_post_return_get_post_return_value(ptr: *mut u8) { + let s_ptr = *ptr.cast::<*mut u8>(); + let s_len = *ptr.add(4).cast::(); + alloc::dealloc(ptr, Layout::from_size_align(8, 4).unwrap()); + + drop(String::from_utf8(Vec::from_raw_parts(s_ptr, s_len, s_len)).unwrap()); +} + +#[cfg(target_arch = "wasm32")] +#[unsafe(link_section = "component-type:wit-bindgen:0.37.0:local:local:post-return-callee:encoded world")] +#[doc(hidden)] +#[allow( + clippy::octal_escapes, + reason = "this is a machine-generated binary blob" +)] +pub static __WIT_BINDGEN_COMPONENT_TYPE: [u8; 255] = *b"\ +\0asm\x0d\0\x01\0\0\x19\x16wit-component-encoding\x04\0\x07w\x01A\x02\x01A\x02\x01\ +B\x04\x01@\x01\x01ss\0s\x04\0\x03foo\x01\0\x01@\0\0s\x04\0\x15get-post-return-va\ +lue\x01\x01\x04\0\x17local:local/post-return\x05\0\x04\0\x1elocal:local/post-ret\ +urn-callee\x04\0\x0b\x18\x01\0\x12post-return-callee\x03\0\0\0G\x09producers\x01\ +\x0cprocessed-by\x02\x0dwit-component\x070.223.0\x10wit-bindgen-rust\x060.37.0"; + +/// # Safety +/// TODO +#[unsafe(export_name = "cabi_realloc")] +pub unsafe extern "C" fn cabi_realloc( + old_ptr: *mut u8, + old_len: usize, + align: usize, + new_size: usize, +) -> *mut u8 { + assert!(old_ptr.is_null()); + assert!(old_len == 0); + + alloc::alloc(Layout::from_size_align(new_size, align).unwrap()) +} + +// Unused function; required since this file is built as a `bin`: +fn main() {} diff --git a/crates/test-programs/src/bin/async_post_return_caller.rs b/crates/test-programs/src/bin/async_post_return_caller.rs new file mode 100644 index 000000000000..7e58e59e6668 --- /dev/null +++ b/crates/test-programs/src/bin/async_post_return_caller.rs @@ -0,0 +1,35 @@ +mod bindings { + wit_bindgen::generate!({ + path: "../misc/component-async-tests/wit", + world: "post-return-caller", + async: { + imports: [ + "local:local/post-return#foo" + ], + exports: [ + "local:local/run#run" + ] + } + }); + + use super::Component; + export!(Component); +} + +use bindings::{ + exports::local::local::run::Guest, + local::local::post_return::{foo, get_post_return_value}, +}; + +struct Component; + +impl Guest for Component { + async fn run() { + let s = "All mimsy were the borogoves"; + assert_eq!(s, &foo(s).await); + assert_eq!(s, &get_post_return_value()); + } +} + +// Unused function; required since this file is built as a `bin`: +fn main() {} diff --git a/crates/test-programs/src/bin/async_round_trip_direct_stackless.rs b/crates/test-programs/src/bin/async_round_trip_direct_stackless.rs new file mode 100644 index 000000000000..fd954b7aab2a --- /dev/null +++ b/crates/test-programs/src/bin/async_round_trip_direct_stackless.rs @@ -0,0 +1,24 @@ +mod bindings { + wit_bindgen::generate!({ + path: "../misc/component-async-tests/wit", + world: "round-trip-direct", + async: true, + }); + + use super::Component; + export!(Component); +} + +struct Component; + +impl bindings::Guest for Component { + async fn foo(s: String) -> String { + format!( + "{} - exited guest", + bindings::foo(&format!("{s} - entered guest")).await + ) + } +} + +// Unused function; required since this file is built as a `bin`: +fn main() {} diff --git a/crates/test-programs/src/bin/async_round_trip_stackful.rs b/crates/test-programs/src/bin/async_round_trip_stackful.rs new file mode 100644 index 000000000000..d2747bc25cc0 --- /dev/null +++ b/crates/test-programs/src/bin/async_round_trip_stackful.rs @@ -0,0 +1,150 @@ +// This tests callback-less (AKA stackful) async exports. +// +// Testing this case using Rust's LLVM-based toolchain is tricky because, as of +// this writing, LLVM does not produce reentrance-safe code. Specifically, it +// allocates a single shadow stack for use whenever a program needs to take the +// address of a stack variable, which makes concurrent execution of multiple +// Wasm stacks in the same instance hazardous. +// +// Given the above, we write code directly against the component model ABI +// rather than use `wit-bindgen`, and we carefully avoid use of the shadow stack +// across yield points such as calls to `task.wait` in order to keep the code +// reentrant. + +use std::alloc::{self, Layout}; + +#[cfg(target_arch = "wasm32")] +#[link(wasm_import_module = "[export]local:local/baz")] +unsafe extern "C" { + #[link_name = "[task-return]foo"] + fn task_return_foo(ptr: *mut u8, len: usize); +} +#[cfg(not(target_arch = "wasm32"))] +unsafe extern "C" fn task_return_foo(_ptr: *mut u8, _len: usize) { + unreachable!() +} + +#[cfg(target_arch = "wasm32")] +#[link(wasm_import_module = "local:local/baz")] +unsafe extern "C" { + #[link_name = "[async]foo"] + fn import_foo(params: *mut u8, results: *mut u8) -> u32; +} +#[cfg(not(target_arch = "wasm32"))] +unsafe extern "C" fn import_foo(_params: *mut u8, _results: *mut u8) -> u32 { + unreachable!() +} + +#[cfg(target_arch = "wasm32")] +#[link(wasm_import_module = "$root")] +unsafe extern "C" { + #[link_name = "[task-wait]"] + fn task_wait(results: *mut i32) -> i32; +} +#[cfg(not(target_arch = "wasm32"))] +unsafe extern "C" fn task_wait(_results: *mut i32) -> i32 { + unreachable!() +} + +#[cfg(target_arch = "wasm32")] +#[link(wasm_import_module = "$root")] +unsafe extern "C" { + #[link_name = "[subtask-drop]"] + fn subtask_drop(task: u32); +} +#[cfg(not(target_arch = "wasm32"))] +unsafe extern "C" fn subtask_drop(_task: u32) { + unreachable!() +} + +const _STATUS_STARTING: u32 = 0; +const _STATUS_STARTED: u32 = 1; +const _STATUS_RETURNED: u32 = 2; +const STATUS_DONE: u32 = 3; + +const _EVENT_CALL_STARTING: i32 = 0; +const _EVENT_CALL_STARTED: i32 = 1; +const _EVENT_CALL_RETURNED: i32 = 2; +const EVENT_CALL_DONE: i32 = 3; + +#[unsafe(export_name = "[async-stackful]local:local/baz#foo")] +unsafe extern "C" fn export_foo(ptr: *mut u8, len: usize) { + // Note that we're careful not to take the address of any stack-allocated + // value here. We need to avoid relying on the LLVM-generated shadow stack + // in order to correctly support reentrancy. It's okay to call functions + // which use the shadow stack, as long as they pop everything off before we + // reach a yield point such as a call to `task.wait`. + + let s = format!( + "{} - entered guest", + String::from_utf8(Vec::from_raw_parts(ptr, len, len)).unwrap() + ); + + let layout = Layout::from_size_align(8, 4).unwrap(); + + let params = alloc::alloc(layout); + *params.cast::<*mut u8>() = s.as_ptr().cast_mut(); + *params.add(4).cast::() = s.len(); + + let results = alloc::alloc(layout); + + let result = import_foo(params, results); + let mut status = result >> 30; + let call = result & !(0b11 << 30); + while status != STATUS_DONE { + // Note the use of `Box` here to avoid taking the address of a stack + // allocation. + let payload = Box::into_raw(Box::new([0i32; 2])); + let event = task_wait(payload.cast()); + let payload = Box::from_raw(payload); + if event == EVENT_CALL_DONE { + assert!(call == payload[0] as u32); + subtask_drop(call); + status = STATUS_DONE; + } + } + alloc::dealloc(params, layout); + + let len = *results.add(4).cast::(); + let s = format!( + "{} - exited guest", + String::from_utf8(Vec::from_raw_parts(*results.cast::<*mut u8>(), len, len)).unwrap() + ); + alloc::dealloc(results, layout); + + task_return_foo(s.as_ptr().cast_mut(), s.len()); +} + +// Copied from `wit-bindgen`-generated output +#[cfg(target_arch = "wasm32")] +#[unsafe(link_section = "component-type:wit-bindgen:0.35.0:local:local:round-trip:encoded world")] +#[doc(hidden)] +#[allow( + clippy::octal_escapes, + reason = "this is a machine-generated binary blob" +)] +pub static __WIT_BINDGEN_COMPONENT_TYPE: [u8; 239] = *b"\ +\0asm\x0d\0\x01\0\0\x19\x16wit-component-encoding\x04\0\x07o\x01A\x02\x01A\x04\x01\ +B\x02\x01@\x01\x01ss\0s\x04\0\x03foo\x01\0\x03\0\x0flocal:local/baz\x05\0\x01B\x02\ +\x01@\x01\x01ss\0s\x04\0\x03foo\x01\0\x04\0\x0flocal:local/baz\x05\x01\x04\0\x16\ +local:local/round-trip\x04\0\x0b\x10\x01\0\x0around-trip\x03\0\0\0G\x09producers\ +\x01\x0cprocessed-by\x02\x0dwit-component\x070.220.0\x10wit-bindgen-rust\x060.35\ +.0"; + +/// # Safety +/// TODO +#[unsafe(export_name = "cabi_realloc")] +pub unsafe extern "C" fn cabi_realloc( + old_ptr: *mut u8, + old_len: usize, + align: usize, + new_size: usize, +) -> *mut u8 { + assert!(old_ptr.is_null()); + assert!(old_len == 0); + + alloc::alloc(Layout::from_size_align(new_size, align).unwrap()) +} + +// Unused function; required since this file is built as a `bin`: +fn main() {} diff --git a/crates/test-programs/src/bin/async_round_trip_stackless.rs b/crates/test-programs/src/bin/async_round_trip_stackless.rs new file mode 100644 index 000000000000..f06bf95571c2 --- /dev/null +++ b/crates/test-programs/src/bin/async_round_trip_stackless.rs @@ -0,0 +1,26 @@ +mod bindings { + wit_bindgen::generate!({ + path: "../misc/component-async-tests/wit", + world: "round-trip", + async: true, + }); + + use super::Component; + export!(Component); +} + +use bindings::{exports::local::local::baz::Guest as Baz, local::local::baz}; + +struct Component; + +impl Baz for Component { + async fn foo(s: String) -> String { + format!( + "{} - exited guest", + baz::foo(&format!("{s} - entered guest")).await + ) + } +} + +// Unused function; required since this file is built as a `bin`: +fn main() {} diff --git a/crates/test-programs/src/bin/async_round_trip_synchronous.rs b/crates/test-programs/src/bin/async_round_trip_synchronous.rs new file mode 100644 index 000000000000..bcf4ccae2104 --- /dev/null +++ b/crates/test-programs/src/bin/async_round_trip_synchronous.rs @@ -0,0 +1,25 @@ +mod bindings { + wit_bindgen::generate!({ + path: "../misc/component-async-tests/wit", + world: "round-trip", + }); + + use super::Component; + export!(Component); +} + +use bindings::{exports::local::local::baz::Guest as Baz, local::local::baz}; + +struct Component; + +impl Baz for Component { + fn foo(s: String) -> String { + format!( + "{} - exited guest", + baz::foo(&format!("{s} - entered guest")) + ) + } +} + +// Unused function; required since this file is built as a `bin`: +fn main() {} diff --git a/crates/test-programs/src/bin/async_round_trip_wait.rs b/crates/test-programs/src/bin/async_round_trip_wait.rs new file mode 100644 index 000000000000..6f3b3ced7fc4 --- /dev/null +++ b/crates/test-programs/src/bin/async_round_trip_wait.rs @@ -0,0 +1,35 @@ +mod bindings { + wit_bindgen::generate!({ + path: "../misc/component-async-tests/wit", + world: "round-trip", + async: { + imports: [ + "local:local/baz#foo", + ] + } + }); + + use super::Component; + export!(Component); +} + +use { + bindings::{exports::local::local::baz::Guest as Baz, local::local::baz}, + wit_bindgen_rt::async_support, +}; + +struct Component; + +impl Baz for Component { + fn foo(s: String) -> String { + async_support::block_on(async move { + format!( + "{} - exited guest", + baz::foo(&format!("{s} - entered guest")).await + ) + }) + } +} + +// Unused function; required since this file is built as a `bin`: +fn main() {} diff --git a/crates/test-programs/src/bin/async_transmit_callee.rs b/crates/test-programs/src/bin/async_transmit_callee.rs new file mode 100644 index 000000000000..b1345e53b5a1 --- /dev/null +++ b/crates/test-programs/src/bin/async_transmit_callee.rs @@ -0,0 +1,77 @@ +mod bindings { + wit_bindgen::generate!({ + path: "../misc/component-async-tests/wit", + world: "transmit-callee", + async: { + exports: [ + "local:local/transmit#exchange", + ], + } + }); + + use super::Component; + export!(Component); +} + +use { + bindings::{ + exports::local::local::transmit::{Control, Guest}, + wit_future, wit_stream, + }, + futures::{SinkExt, StreamExt}, + std::future::IntoFuture, + wit_bindgen_rt::async_support::{self, FutureReader, StreamReader}, +}; + +struct Component; + +impl Guest for Component { + async fn exchange( + mut control_rx: StreamReader, + mut caller_stream_rx: StreamReader, + caller_future_rx1: FutureReader, + caller_future_rx2: FutureReader, + ) -> ( + StreamReader, + FutureReader, + FutureReader, + ) { + let (mut callee_stream_tx, callee_stream_rx) = wit_stream::new(); + let (callee_future_tx1, callee_future_rx1) = wit_future::new(); + let (callee_future_tx2, callee_future_rx2) = wit_future::new(); + + async_support::spawn(async move { + let mut caller_future_rx1 = Some(caller_future_rx1); + let mut callee_future_tx1 = Some(callee_future_tx1); + + while let Some(messages) = control_rx.next().await { + for message in messages { + match message { + Control::ReadStream(value) => { + assert_eq!(caller_stream_rx.next().await, Some(vec![value])); + } + Control::ReadFuture(value) => { + assert_eq!( + caller_future_rx1.take().unwrap().into_future().await, + Some(value) + ); + } + Control::WriteStream(value) => { + callee_stream_tx.send(vec![value]).await.unwrap(); + } + Control::WriteFuture(value) => { + callee_future_tx1.take().unwrap().write(value).await; + } + } + } + } + + drop((caller_future_rx2, callee_future_tx2)); + }); + + (callee_stream_rx, callee_future_rx1, callee_future_rx2) + } +} + +// Unused function; required since this file is built as a `bin`: +fn main() {} diff --git a/crates/test-programs/src/bin/async_transmit_caller.rs b/crates/test-programs/src/bin/async_transmit_caller.rs new file mode 100644 index 000000000000..2612ba057030 --- /dev/null +++ b/crates/test-programs/src/bin/async_transmit_caller.rs @@ -0,0 +1,166 @@ +mod bindings { + wit_bindgen::generate!({ + path: "../misc/component-async-tests/wit", + world: "transmit-caller", + async: { + imports: [ + "local:local/transmit#exchange", + ], + exports: [ + "local:local/run#run", + ], + } + }); + + use super::Component; + export!(Component); +} + +use { + bindings::{ + exports::local::local::run::Guest, + local::local::transmit::{self, Control}, + wit_future, wit_stream, + }, + futures::{future, FutureExt, SinkExt, StreamExt}, + std::{ + future::{Future, IntoFuture}, + pin::pin, + task::Poll, + }, +}; + +struct Component; + +impl Guest for Component { + async fn run() { + let (mut control_tx, control_rx) = wit_stream::new(); + let (mut caller_stream_tx, caller_stream_rx) = wit_stream::new(); + let (mut caller_future_tx1, caller_future_rx1) = wit_future::new(); + let (caller_future_tx2, caller_future_rx2) = wit_future::new(); + + let (mut callee_stream_rx, mut callee_future_rx1, callee_future_rx2) = transmit::exchange( + control_rx, + caller_stream_rx, + caller_future_rx1, + caller_future_rx2, + ) + .await; + + // Tell peer to read from its end of the stream and assert that the result matches an expected value. + control_tx + .send(vec![Control::ReadStream("a".into())]) + .await + .unwrap(); + caller_stream_tx.send(vec!["a".into()]).await.unwrap(); + + // Start writing another value, but cancel the write before telling the peer to read. + { + let send = caller_stream_tx.send(vec!["b".into()]); + assert!(poll(send).await.is_err()); + caller_stream_tx.cancel(); + } + + // Tell the peer to read an expected value again, which should _not_ match the value provided in the + // canceled write above. + control_tx + .send(vec![Control::ReadStream("c".into())]) + .await + .unwrap(); + caller_stream_tx.send(vec!["c".into()]).await.unwrap(); + + // Start writing a value to the future, but cancel the write before telling the peer to read. + { + let send = caller_future_tx1.write("x".into()); + match poll(send).await { + Ok(_) => panic!(), + Err(send) => caller_future_tx1 = send.cancel(), + } + } + + // Tell the peer to read an expected value again, which should _not_ match the value provided in the + // canceled write above. + control_tx + .send(vec![Control::ReadFuture("y".into())]) + .await + .unwrap(); + caller_future_tx1.write("y".into()).await; + + // Tell the peer to write a value to its end of the stream, then read from our end and assert the value + // matches. + control_tx + .send(vec![Control::WriteStream("a".into())]) + .await + .unwrap(); + assert_eq!(callee_stream_rx.next().await, Some(vec!["a".into()])); + + // Start reading a value from the stream, but cancel the read before telling the peer to write. + { + let next = callee_stream_rx.next(); + assert!(poll(next).await.is_err()); + callee_stream_rx.cancel(); + } + + // Once again, tell the peer to write a value to its end of the stream, then read from our end and assert + // the value matches. + control_tx + .send(vec![Control::WriteStream("b".into())]) + .await + .unwrap(); + assert_eq!(callee_stream_rx.next().await, Some(vec!["b".into()])); + + // Start reading a value from the future, but cancel the read before telling the peer to write. + { + let next = callee_future_rx1.into_future(); + match poll(next).await { + Ok(_) => panic!(), + Err(next) => callee_future_rx1 = next.cancel(), + } + } + + // Tell the peer to write a value to its end of the future, then read from our end and assert the value + // matches. + control_tx + .send(vec![Control::WriteFuture("b".into())]) + .await + .unwrap(); + assert_eq!(callee_future_rx1.into_future().await, Some("b".into())); + + // Start writing a value to the stream, but drop the stream without telling the peer to read. + let send = caller_stream_tx.send(vec!["d".into()]); + assert!(poll(send).await.is_err()); + drop(caller_stream_tx); + + // Start reading a value from the stream, but drop the stream without telling the peer to write. + let next = callee_stream_rx.next(); + assert!(poll(next).await.is_err()); + drop(callee_stream_rx); + + // Start writing a value to the future, but drop the write without telling the peer to read. + { + let send = pin!(caller_future_tx2.write("x".into())); + assert!(poll(send).await.is_err()); + } + + // Start reading a value from the future, but drop the read without telling the peer to write. + { + let next = callee_future_rx2.into_future(); + assert!(poll(next).await.is_err()); + } + } +} + +async fn poll + Unpin>(fut: F) -> Result { + let mut fut = Some(fut); + future::poll_fn(move |cx| { + let mut fut = fut.take().unwrap(); + Poll::Ready(match fut.poll_unpin(cx) { + Poll::Ready(v) => Ok(v), + Poll::Pending => Err(fut), + }) + }) + .await +} + +// Unused function; required since this file is built as a `bin`: +fn main() {} diff --git a/crates/test-programs/src/bin/async_unit_stream_callee.rs b/crates/test-programs/src/bin/async_unit_stream_callee.rs new file mode 100644 index 000000000000..ffd9ed380f51 --- /dev/null +++ b/crates/test-programs/src/bin/async_unit_stream_callee.rs @@ -0,0 +1,46 @@ +mod bindings { + wit_bindgen::generate!({ + path: "../misc/component-async-tests/wit", + world: "unit-stream-callee", + async: { + exports: [ + "local:local/unit-stream#run", + ], + } + }); + + use super::Component; + export!(Component); +} + +use { + bindings::{exports::local::local::unit_stream::Guest, wit_stream}, + futures::SinkExt, + wit_bindgen_rt::async_support::{self, StreamReader}, +}; + +struct Component; + +impl Guest for Component { + async fn run(count: u32) -> StreamReader<()> { + let (mut tx, rx) = wit_stream::new(); + + async_support::spawn(async move { + let mut sent = 0; + let mut chunk_size = 1; + while sent < count { + let n = (count - sent).min(chunk_size); + tx.send(vec![(); usize::try_from(n).unwrap()]) + .await + .unwrap(); + sent += n; + chunk_size *= 2; + } + }); + + rx + } +} + +// Unused function; required since this file is built as a `bin`: +fn main() {} diff --git a/crates/test-programs/src/bin/async_unit_stream_caller.rs b/crates/test-programs/src/bin/async_unit_stream_caller.rs new file mode 100644 index 000000000000..878ea225bbd3 --- /dev/null +++ b/crates/test-programs/src/bin/async_unit_stream_caller.rs @@ -0,0 +1,41 @@ +mod bindings { + wit_bindgen::generate!({ + path: "../misc/component-async-tests/wit", + world: "unit-stream-caller", + async: { + imports: [ + "local:local/unit-stream#run", + ], + exports: [ + "local:local/run#run", + ], + } + }); + + use super::Component; + export!(Component); +} + +use { + bindings::{exports::local::local::run::Guest, local::local::unit_stream}, + futures::StreamExt, +}; + +struct Component; + +impl Guest for Component { + async fn run() { + let count = 42; + let mut rx = unit_stream::run(count).await; + + let mut received = 0; + while let Some(chunk) = rx.next().await { + received += chunk.len(); + } + + assert_eq!(count, u32::try_from(received).unwrap()); + } +} + +// Unused function; required since this file is built as a `bin`: +fn main() {} diff --git a/crates/test-programs/src/bin/async_yield_callee.rs b/crates/test-programs/src/bin/async_yield_callee.rs new file mode 100644 index 000000000000..4274546ce3dd --- /dev/null +++ b/crates/test-programs/src/bin/async_yield_callee.rs @@ -0,0 +1,27 @@ +mod bindings { + wit_bindgen::generate!({ + path: "../misc/component-async-tests/wit", + world: "yield-callee", + }); + + use super::Component; + export!(Component); +} + +use { + bindings::{exports::local::local::run::Guest, local::local::continue_}, + wit_bindgen_rt::async_support, +}; + +struct Component; + +impl Guest for Component { + fn run() { + while continue_::get_continue() { + async_support::task_yield(); + } + } +} + +// Unused function; required since this file is built as a `bin`: +fn main() {} diff --git a/crates/test-programs/src/bin/async_yield_caller.rs b/crates/test-programs/src/bin/async_yield_caller.rs new file mode 100644 index 000000000000..3cdd13ade127 --- /dev/null +++ b/crates/test-programs/src/bin/async_yield_caller.rs @@ -0,0 +1,62 @@ +mod bindings { + wit_bindgen::generate!({ + path: "../misc/component-async-tests/wit", + world: "yield-caller", + async: { + imports: [ + "local:local/ready#when-ready", + "local:local/run#run", + ], + exports: [ + "local:local/run#run", + ], + } + }); + + use super::Component; + export!(Component); +} + +use { + bindings::{ + exports::local::local::run::Guest, + local::local::{continue_, ready, run}, + }, + futures::future, + std::{future::Future, task::Poll}, +}; + +struct Component; + +impl Guest for Component { + async fn run() { + ready::set_ready(false); + continue_::set_continue(true); + + let mut ready = Some(Box::pin(ready::when_ready())); + let mut run = Some(Box::pin(run::run())); + future::poll_fn(move |cx| { + let ready_poll = ready.as_mut().map(|v| v.as_mut().poll(cx)); + ready::set_ready(true); + let run_poll = run.as_mut().map(|v| v.as_mut().poll(cx)); + + match (run_poll, ready_poll) { + (None | Some(Poll::Ready(())), None | Some(Poll::Ready(()))) => { + return Poll::Ready(()); + } + (Some(Poll::Ready(())), _) => run = None, + (_, Some(Poll::Ready(()))) => { + ready = None; + continue_::set_continue(false); + } + _ => {} + } + + Poll::Pending + }) + .await + } +} + +// Unused function; required since this file is built as a `bin`: +fn main() {} diff --git a/crates/wasi-config/Cargo.toml b/crates/wasi-config/Cargo.toml index 81bd61ef7184..cadaf77bb7a2 100644 --- a/crates/wasi-config/Cargo.toml +++ b/crates/wasi-config/Cargo.toml @@ -13,7 +13,7 @@ workspace = true [dependencies] anyhow = { workspace = true } -wasmtime = { workspace = true, features = ["runtime", "component-model"] } +wasmtime = { workspace = true, features = ["runtime", "component-model", "async"] } [dev-dependencies] test-programs-artifacts = { workspace = true } diff --git a/crates/wasmtime/src/config.rs b/crates/wasmtime/src/config.rs index 9ab3aca4dabf..3c0c7629ebe0 100644 --- a/crates/wasmtime/src/config.rs +++ b/crates/wasmtime/src/config.rs @@ -1114,8 +1114,6 @@ impl Config { /// lifting and lowering functions, as well as `stream`, `future`, and /// `error-context` types. /// - /// Please note that Wasmtime's support for this feature is _very_ incomplete. - /// /// [proposal]: https://github.com/WebAssembly/component-model/blob/main/design/mvp/Async.md #[cfg(feature = "component-model-async")] pub fn wasm_component_model_async(&mut self, enable: bool) -> &mut Self { diff --git a/crates/wasmtime/src/runtime/component/concurrent.rs b/crates/wasmtime/src/runtime/component/concurrent.rs index 1d05f0ecb2c3..ef2f99fdcb2a 100644 --- a/crates/wasmtime/src/runtime/component/concurrent.rs +++ b/crates/wasmtime/src/runtime/component/concurrent.rs @@ -1,18 +1,88 @@ use { crate::{ + component::func::{self, Func, Lower as _, LowerContext, Options}, store::StoreInner, - vm::{VMFuncRef, VMMemoryDefinition}, - AsContextMut, ValRaw, + vm::{ + component::{ + CallContext, ComponentInstance, InstanceFlags, ResourceTables, WaitableState, + }, + mpk::{self, ProtectionMask}, + AsyncWasmCallState, PreviousAsyncWasmCallState, SendSyncPtr, VMFuncRef, + VMMemoryDefinition, VMStore, + }, + AsContextMut, Engine, StoreContextMut, ValRaw, }, - anyhow::Result, - futures::{stream::FuturesUnordered, FutureExt}, - std::{boxed::Box, future::Future, pin::Pin}, - wasmtime_environ::component::{RuntimeComponentInstanceIndex, TypeTupleIndex}, + anyhow::{anyhow, bail, Context as _, Result}, + futures::{ + channel::oneshot, + future::{self, Either, FutureExt}, + stream::{FuturesUnordered, StreamExt}, + }, + once_cell::sync::Lazy, + ready_chunks::ReadyChunks, + std::{ + any::Any, + borrow::ToOwned, + boxed::Box, + cell::UnsafeCell, + collections::{HashMap, HashSet, VecDeque}, + future::Future, + marker::PhantomData, + mem::{self, MaybeUninit}, + ops::Range, + pin::{pin, Pin}, + ptr::{self, NonNull}, + sync::{Arc, Mutex}, + task::{Context, Poll, Wake, Waker}, + vec::Vec, + }, + table::{Table, TableId}, + wasmtime_environ::component::{ + InterfaceType, RuntimeComponentInstanceIndex, StringEncoding, + TypeComponentLocalErrorContextTableIndex, TypeFutureTableIndex, TypeStreamTableIndex, + TypeTupleIndex, MAX_FLAT_PARAMS, MAX_FLAT_RESULTS, + }, + wasmtime_fiber::{Fiber, Suspend}, }; -pub use futures_and_streams::{ErrorContext, FutureReader, StreamReader}; +pub use futures_and_streams::{ + future, stream, ErrorContext, FutureReader, FutureWriter, StreamReader, StreamWriter, +}; +use futures_and_streams::{FlatAbi, TableIndex, TransmitState}; mod futures_and_streams; +mod ready_chunks; +mod table; + +// TODO: The handling of `task.yield` and `task.backpressure` was bolted on late in the implementation and is +// currently haphazard. We need a refactor to manage yielding, backpressure, and event polling and delivery in a +// more unified and structured way. + +#[derive(Clone, Copy, Eq, PartialEq)] +#[repr(u32)] +enum Status { + Starting, + Started, + Returned, + Done, +} + +#[derive(Clone, Copy, Eq, PartialEq, Debug)] +#[repr(u32)] +enum Event { + _Starting, + Started, + Returned, + Done, + _Yielded, + StreamRead, + StreamWrite, + FutureRead, + FutureWrite, +} + +const EXIT_FLAG_ASYNC_CALLER: u32 = 1 << 0; +const EXIT_FLAG_ASYNC_CALLEE: u32 = 1 << 1; /// Represents the result of a concurrent operation. /// @@ -34,9 +104,8 @@ impl Promise { /// The returned future will require exclusive use of the store until it /// completes. If you need to await more than one `Promise` concurrently, /// use [`PromisesUnordered`]. - pub async fn get(self, store: impl AsContextMut) -> Result { - _ = store; - todo!() + pub async fn get(self, mut store: impl AsContextMut) -> Result { + Ok(poll_until(store.as_context_mut(), self.0).await?.1) } /// Convert this `Promise` to a future which may be `await`ed for its @@ -72,9 +141,11 @@ impl PromisesUnordered { } /// Get the next result from this collection, if any. - pub async fn next(&mut self, store: impl AsContextMut) -> Result> { - _ = store; - todo!() + pub async fn next( + &mut self, + mut store: impl AsContextMut, + ) -> Result> { + Ok(poll_until(store.as_context_mut(), self.0.next()).await?.1) } } @@ -91,6 +162,7 @@ pub unsafe trait VMComponentAsyncStore { /// The `task.return` intrinsic. fn task_return( &mut self, + instance: &mut ComponentInstance, ty: TypeTupleIndex, storage: *mut ValRaw, storage_len: usize, @@ -99,6 +171,7 @@ pub unsafe trait VMComponentAsyncStore { /// The `task.wait` intrinsic. fn task_wait( &mut self, + instance: &mut ComponentInstance, caller_instance: RuntimeComponentInstanceIndex, async_: bool, memory: *mut VMMemoryDefinition, @@ -108,6 +181,7 @@ pub unsafe trait VMComponentAsyncStore { /// The `task.poll` intrinsic. fn task_poll( &mut self, + instance: &mut ComponentInstance, caller_instance: RuntimeComponentInstanceIndex, async_: bool, memory: *mut VMMemoryDefinition, @@ -115,11 +189,12 @@ pub unsafe trait VMComponentAsyncStore { ) -> Result; /// The `task.yield` intrinsic. - fn task_yield(&mut self, async_: bool) -> Result<()>; + fn task_yield(&mut self, instance: &mut ComponentInstance, async_: bool) -> Result<()>; /// The `subtask.drop` intrinsic. fn subtask_drop( &mut self, + instance: &mut ComponentInstance, caller_instance: RuntimeComponentInstanceIndex, task_id: u32, ) -> Result<()>; @@ -140,6 +215,7 @@ pub unsafe trait VMComponentAsyncStore { /// both of the functions involved are async functions. fn async_exit( &mut self, + instance: &mut ComponentInstance, callback: *mut VMFuncRef, post_return: *mut VMFuncRef, caller_instance: RuntimeComponentInstanceIndex, @@ -149,6 +225,202 @@ pub unsafe trait VMComponentAsyncStore { result_count: u32, flags: u32, ) -> Result; + + /// The `future.new` intrinsic. + fn future_new( + &mut self, + instance: &mut ComponentInstance, + ty: TypeFutureTableIndex, + ) -> Result; + + /// The `future.write` intrinsic. + fn future_write( + &mut self, + instance: &mut ComponentInstance, + memory: *mut VMMemoryDefinition, + realloc: *mut VMFuncRef, + string_encoding: u8, + ty: TypeFutureTableIndex, + future: u32, + address: u32, + ) -> Result; + + /// The `future.read` intrinsic. + fn future_read( + &mut self, + instance: &mut ComponentInstance, + memory: *mut VMMemoryDefinition, + realloc: *mut VMFuncRef, + string_encoding: u8, + ty: TypeFutureTableIndex, + future: u32, + address: u32, + ) -> Result; + + /// The `future.cancel-write` intrinsic. + fn future_cancel_write( + &mut self, + instance: &mut ComponentInstance, + ty: TypeFutureTableIndex, + async_: bool, + writer: u32, + ) -> Result; + + /// The `future.cancel-read` intrinsic. + fn future_cancel_read( + &mut self, + instance: &mut ComponentInstance, + ty: TypeFutureTableIndex, + async_: bool, + reader: u32, + ) -> Result; + + /// The `future.close-writable` intrinsic. + fn future_close_writable( + &mut self, + instance: &mut ComponentInstance, + ty: TypeFutureTableIndex, + writer: u32, + error: u32, + ) -> Result<()>; + + /// The `future.close-readable` intrinsic. + fn future_close_readable( + &mut self, + instance: &mut ComponentInstance, + ty: TypeFutureTableIndex, + reader: u32, + ) -> Result<()>; + + /// The `stream.new` intrinsic. + fn stream_new( + &mut self, + instance: &mut ComponentInstance, + ty: TypeStreamTableIndex, + ) -> Result; + + /// The `stream.write` intrinsic. + fn stream_write( + &mut self, + instance: &mut ComponentInstance, + memory: *mut VMMemoryDefinition, + realloc: *mut VMFuncRef, + string_encoding: u8, + ty: TypeStreamTableIndex, + stream: u32, + address: u32, + count: u32, + ) -> Result; + + /// The `stream.read` intrinsic. + fn stream_read( + &mut self, + instance: &mut ComponentInstance, + memory: *mut VMMemoryDefinition, + realloc: *mut VMFuncRef, + string_encoding: u8, + ty: TypeStreamTableIndex, + stream: u32, + address: u32, + count: u32, + ) -> Result; + + /// The `stream.cancel-write` intrinsic. + fn stream_cancel_write( + &mut self, + instance: &mut ComponentInstance, + ty: TypeStreamTableIndex, + async_: bool, + writer: u32, + ) -> Result; + + /// The `stream.cancel-read` intrinsic. + fn stream_cancel_read( + &mut self, + instance: &mut ComponentInstance, + ty: TypeStreamTableIndex, + async_: bool, + reader: u32, + ) -> Result; + + /// The `stream.close-writable` intrinsic. + fn stream_close_writable( + &mut self, + instance: &mut ComponentInstance, + ty: TypeStreamTableIndex, + writer: u32, + error: u32, + ) -> Result<()>; + + /// The `stream.close-readable` intrinsic. + fn stream_close_readable( + &mut self, + instance: &mut ComponentInstance, + ty: TypeStreamTableIndex, + reader: u32, + ) -> Result<()>; + + /// The "fast-path" implementation of the `stream.write` intrinsic for + /// "flat" (i.e. memcpy-able) payloads. + fn flat_stream_write( + &mut self, + instance: &mut ComponentInstance, + memory: *mut VMMemoryDefinition, + realloc: *mut VMFuncRef, + ty: TypeStreamTableIndex, + payload_size: u32, + payload_align: u32, + stream: u32, + address: u32, + count: u32, + ) -> Result; + + /// The "fast-path" implementation of the `stream.read` intrinsic for "flat" + /// (i.e. memcpy-able) payloads. + fn flat_stream_read( + &mut self, + instance: &mut ComponentInstance, + memory: *mut VMMemoryDefinition, + realloc: *mut VMFuncRef, + ty: TypeStreamTableIndex, + payload_size: u32, + payload_align: u32, + stream: u32, + address: u32, + count: u32, + ) -> Result; + + /// The `error-context.new` intrinsic. + fn error_context_new( + &mut self, + instance: &mut ComponentInstance, + memory: *mut VMMemoryDefinition, + realloc: *mut VMFuncRef, + string_encoding: u8, + ty: TypeComponentLocalErrorContextTableIndex, + debug_msg_address: u32, + debug_msg_len: u32, + ) -> Result; + + /// The `error-context.debug-message` intrinsic. + fn error_context_debug_message( + &mut self, + instance: &mut ComponentInstance, + memory: *mut VMMemoryDefinition, + realloc: *mut VMFuncRef, + string_encoding: u8, + ty: TypeComponentLocalErrorContextTableIndex, + err_ctx_handle: u32, + debug_msg_address: u32, + ) -> Result<()>; + + /// The `error-context.drop` intrinsic. + fn error_context_drop( + &mut self, + instance: &mut ComponentInstance, + ty: TypeComponentLocalErrorContextTableIndex, + err_ctx_handle: u32, + ) -> Result<()>; } unsafe impl VMComponentAsyncStore for StoreInner { @@ -157,54 +429,141 @@ unsafe impl VMComponentAsyncStore for StoreInner { caller_instance: RuntimeComponentInstanceIndex, enabled: u32, ) -> Result<()> { - _ = (caller_instance, enabled); - todo!() + let mut cx = StoreContextMut(self); + let entry = cx + .concurrent_state() + .instance_states + .entry(caller_instance) + .or_default(); + let old = entry.backpressure; + let new = enabled != 0; + entry.backpressure = new; + + if old && !new && !entry.task_queue.is_empty() { + cx.concurrent_state().unblocked.insert(caller_instance); + } + + Ok(()) } fn task_return( &mut self, + instance: &mut ComponentInstance, ty: TypeTupleIndex, storage: *mut ValRaw, storage_len: usize, ) -> Result<()> { - _ = (ty, storage, storage_len); - todo!() + let storage = unsafe { std::slice::from_raw_parts(storage, storage_len) }; + let mut cx = StoreContextMut(self); + let guest_task = cx.concurrent_state().guest_task.unwrap(); + let (lift, lift_ty) = cx + .concurrent_state() + .table + .get_mut(guest_task)? + .lift_result + .take() + .ok_or_else(|| anyhow!("`task.return` called more than once"))?; + + if ty != lift_ty { + bail!("invalid `task.return` signature for current task"); + } + + assert!(cx + .concurrent_state() + .table + .get(guest_task)? + .result + .is_none()); + + log::trace!("task.return for {}", guest_task.rep()); + + let cx = cx.0.traitobj().as_ptr(); + let result = lift(cx, storage)?; + let mut cx = unsafe { StoreContextMut::(&mut *cx.cast()) }; + + let (calls, host_table, _) = cx.0.component_resource_state(); + ResourceTables { + calls, + host_table: Some(host_table), + tables: Some((*instance).component_resource_tables()), + } + .exit_call()?; + + if let Caller::Host(tx) = &mut cx.concurrent_state().table.get_mut(guest_task)?.caller { + _ = tx.take().unwrap().send(result); + } else { + cx.concurrent_state().table.get_mut(guest_task)?.result = Some(result); + } + + Ok(()) } fn task_wait( &mut self, + instance: &mut ComponentInstance, caller_instance: RuntimeComponentInstanceIndex, async_: bool, memory: *mut VMMemoryDefinition, payload: u32, ) -> Result { - _ = (caller_instance, async_, memory, payload); - todo!() + task_check( + StoreContextMut(self), + instance, + async_, + TaskCheck::Wait(memory, payload, caller_instance), + ) } fn task_poll( &mut self, + instance: &mut ComponentInstance, caller_instance: RuntimeComponentInstanceIndex, async_: bool, memory: *mut VMMemoryDefinition, payload: u32, ) -> Result { - _ = (caller_instance, async_, memory, payload); - todo!() + task_check( + StoreContextMut(self), + instance, + async_, + TaskCheck::Poll(memory, payload, caller_instance), + ) } - fn task_yield(&mut self, async_: bool) -> Result<()> { - _ = async_; - todo!() + fn task_yield(&mut self, instance: &mut ComponentInstance, async_: bool) -> Result<()> { + task_check(StoreContextMut(self), instance, async_, TaskCheck::Yield).map(drop) } fn subtask_drop( &mut self, + instance: &mut ComponentInstance, caller_instance: RuntimeComponentInstanceIndex, task_id: u32, ) -> Result<()> { - _ = (caller_instance, task_id); - todo!() + let mut cx = StoreContextMut(self); + let (rep, WaitableState::Task) = + instance.component_waitable_tables()[caller_instance].remove_by_index(task_id)? + else { + bail!("invalid task handle: {task_id}"); + }; + let table = &mut cx.concurrent_state().table; + log::trace!("subtask_drop delete {rep}"); + let task = table.delete_any(rep)?; + let expected_caller_instance = match task.downcast::() { + Ok(task) => task.caller_instance, + Err(task) => match task.downcast::() { + Ok(task) => { + if let Caller::Guest { instance, .. } = task.caller { + instance + } else { + unreachable!() + } + } + Err(_) => unreachable!(), + }, + }; + assert_eq!(expected_caller_instance, caller_instance); + Ok(()) } fn async_enter( @@ -216,19 +575,93 @@ unsafe impl VMComponentAsyncStore for StoreInner { params: u32, results: u32, ) -> Result<()> { - _ = ( - start, - return_, - caller_instance, - task_return_type, - params, - results, - ); - todo!() + let mut cx = StoreContextMut(self); + let start = SendSyncPtr::new(NonNull::new(start).unwrap()); + let return_ = SendSyncPtr::new(NonNull::new(return_).unwrap()); + let old_task = cx.concurrent_state().guest_task.take(); + let old_task_rep = old_task.map(|v| v.rep()); + let new_task = GuestTask { + lower_params: Some(Box::new(move |cx, dst| { + let mut cx = unsafe { StoreContextMut::(&mut *cx.cast()) }; + assert!(dst.len() <= MAX_FLAT_PARAMS); + let mut src = [MaybeUninit::uninit(); MAX_FLAT_PARAMS]; + src[0] = MaybeUninit::new(ValRaw::u32(params)); + unsafe { + crate::Func::call_unchecked_raw( + &mut cx, + start.as_non_null(), + NonNull::new( + &mut src[..1.max(dst.len())] as *mut [MaybeUninit] as _, + ) + .unwrap(), + )?; + } + dst.copy_from_slice(&src[..dst.len()]); + let task = cx.concurrent_state().guest_task.unwrap(); + if let Some(rep) = old_task_rep { + maybe_send_event( + cx, + TableId::new(rep), + Event::Started, + AnyTask::Guest(task), + 0, + )?; + } + Ok(()) + })), + lift_result: Some(( + Box::new(move |cx, src| { + let mut cx = unsafe { StoreContextMut::(&mut *cx.cast()) }; + let mut my_src = src.to_owned(); // TODO: use stack to avoid allocation? + my_src.push(ValRaw::u32(results)); + unsafe { + crate::Func::call_unchecked_raw( + &mut cx, + return_.as_non_null(), + my_src.as_mut_slice().into(), + )?; + } + let task = cx.concurrent_state().guest_task.unwrap(); + if let Some(rep) = old_task_rep { + maybe_send_event( + cx, + TableId::new(rep), + Event::Returned, + AnyTask::Guest(task), + 0, + )?; + } + Ok(Box::new(DummyResult) as Box) + }), + task_return_type, + )), + result: None, + callback: None, + caller: Caller::Guest { + task: old_task.unwrap(), + instance: caller_instance, + }, + deferred: Deferred::None, + events: VecDeque::new(), + should_yield: false, + call_context: Some(CallContext::default()), + }; + let guest_task = if let Some(old_task) = old_task { + let child = cx.concurrent_state().table.push_child(new_task, old_task)?; + log::trace!("new child of {}: {}", old_task.rep(), child.rep()); + child + } else { + cx.concurrent_state().table.push(new_task)? + }; + + cx.concurrent_state().guest_task = Some(guest_task); + + Ok(()) } fn async_exit( &mut self, + instance: &mut ComponentInstance, callback: *mut VMFuncRef, post_return: *mut VMFuncRef, caller_instance: RuntimeComponentInstanceIndex, @@ -238,16 +671,2243 @@ unsafe impl VMComponentAsyncStore for StoreInner { result_count: u32, flags: u32, ) -> Result { - _ = ( - callback, - post_return, - caller_instance, + let mut cx = StoreContextMut(self); + let guest_task = cx.concurrent_state().guest_task.unwrap(); + let callee = SendSyncPtr::new(NonNull::new(callee).unwrap()); + let param_count = usize::try_from(param_count).unwrap(); + assert!(param_count <= MAX_FLAT_PARAMS); + let result_count = usize::try_from(result_count).unwrap(); + assert!(result_count <= MAX_FLAT_RESULTS); + + let call = make_call( + guest_task, callee, callee_instance, param_count, result_count, - flags, + if callback.is_null() { + None + } else { + Some(instance.instance_flags(callee_instance)) + }, + ); + + let (guest_context, new_cx) = do_start_call( + cx, + instance, + guest_task, + (flags & EXIT_FLAG_ASYNC_CALLEE) != 0, + call, + NonNull::new(callback).map(SendSyncPtr::new), + NonNull::new(post_return).map(SendSyncPtr::new), + callee_instance, + result_count, + )?; + + cx = new_cx; + + let task = cx.concurrent_state().table.get(guest_task)?; + + let mut status = if task.lower_params.is_some() { + Status::Starting + } else if task.lift_result.is_some() { + Status::Started + } else if guest_context != 0 || callback.is_null() { + Status::Returned + } else { + Status::Done + }; + + let call = if status != Status::Done { + if (flags & EXIT_FLAG_ASYNC_CALLER) != 0 { + instance.component_waitable_tables()[caller_instance] + .insert(guest_task.rep(), WaitableState::Task)? + } else { + poll_for_result(cx)?; + status = Status::Done; + 0 + } + } else { + 0 + }; + + Ok(((status as u32) << 30) | call) + } + + fn future_new( + &mut self, + instance: &mut ComponentInstance, + ty: TypeFutureTableIndex, + ) -> Result { + futures_and_streams::guest_new(StoreContextMut(self), instance, TableIndex::Future(ty)) + } + + fn future_write( + &mut self, + instance: &mut ComponentInstance, + memory: *mut VMMemoryDefinition, + realloc: *mut VMFuncRef, + string_encoding: u8, + ty: TypeFutureTableIndex, + future: u32, + address: u32, + ) -> Result { + futures_and_streams::guest_write( + StoreContextMut(self), + instance, + memory, + realloc, + string_encoding, + TableIndex::Future(ty), + None, + future, + address, + 1, + ) + } + + fn future_read( + &mut self, + instance: &mut ComponentInstance, + memory: *mut VMMemoryDefinition, + realloc: *mut VMFuncRef, + string_encoding: u8, + ty: TypeFutureTableIndex, + future: u32, + address: u32, + ) -> Result { + futures_and_streams::guest_read( + StoreContextMut(self), + instance, + memory, + realloc, + string_encoding, + TableIndex::Future(ty), + None, + future, + address, + 1, + ) + } + + fn future_cancel_write( + &mut self, + instance: &mut ComponentInstance, + ty: TypeFutureTableIndex, + async_: bool, + writer: u32, + ) -> Result { + futures_and_streams::guest_cancel_write( + StoreContextMut(self), + instance, + TableIndex::Future(ty), + writer, + async_, + ) + } + + fn future_cancel_read( + &mut self, + instance: &mut ComponentInstance, + ty: TypeFutureTableIndex, + async_: bool, + reader: u32, + ) -> Result { + futures_and_streams::guest_cancel_read( + StoreContextMut(self), + instance, + TableIndex::Future(ty), + reader, + async_, + ) + } + + fn future_close_writable( + &mut self, + instance: &mut ComponentInstance, + ty: TypeFutureTableIndex, + writer: u32, + error: u32, + ) -> Result<()> { + futures_and_streams::guest_close_writable( + StoreContextMut(self), + instance, + TableIndex::Future(ty), + writer, + error, + ) + } + + fn future_close_readable( + &mut self, + instance: &mut ComponentInstance, + ty: TypeFutureTableIndex, + reader: u32, + ) -> Result<()> { + futures_and_streams::guest_close_readable( + StoreContextMut(self), + instance, + TableIndex::Future(ty), + reader, + ) + } + + fn stream_new( + &mut self, + instance: &mut ComponentInstance, + ty: TypeStreamTableIndex, + ) -> Result { + futures_and_streams::guest_new(StoreContextMut(self), instance, TableIndex::Stream(ty)) + } + + fn stream_write( + &mut self, + instance: &mut ComponentInstance, + memory: *mut VMMemoryDefinition, + realloc: *mut VMFuncRef, + string_encoding: u8, + ty: TypeStreamTableIndex, + stream: u32, + address: u32, + count: u32, + ) -> Result { + futures_and_streams::guest_write( + StoreContextMut(self), + instance, + memory, + realloc, + string_encoding, + TableIndex::Stream(ty), + None, + stream, + address, + count, + ) + } + + fn stream_read( + &mut self, + instance: &mut ComponentInstance, + memory: *mut VMMemoryDefinition, + realloc: *mut VMFuncRef, + string_encoding: u8, + ty: TypeStreamTableIndex, + stream: u32, + address: u32, + count: u32, + ) -> Result { + futures_and_streams::guest_read( + StoreContextMut(self), + instance, + memory, + realloc, + string_encoding, + TableIndex::Stream(ty), + None, + stream, + address, + count, + ) + } + + fn stream_cancel_write( + &mut self, + instance: &mut ComponentInstance, + ty: TypeStreamTableIndex, + async_: bool, + writer: u32, + ) -> Result { + futures_and_streams::guest_cancel_write( + StoreContextMut(self), + instance, + TableIndex::Stream(ty), + writer, + async_, + ) + } + + fn stream_cancel_read( + &mut self, + instance: &mut ComponentInstance, + ty: TypeStreamTableIndex, + async_: bool, + reader: u32, + ) -> Result { + futures_and_streams::guest_cancel_read( + StoreContextMut(self), + instance, + TableIndex::Stream(ty), + reader, + async_, + ) + } + + fn stream_close_writable( + &mut self, + instance: &mut ComponentInstance, + ty: TypeStreamTableIndex, + writer: u32, + error: u32, + ) -> Result<()> { + futures_and_streams::guest_close_writable( + StoreContextMut(self), + instance, + TableIndex::Stream(ty), + writer, + error, + ) + } + + fn stream_close_readable( + &mut self, + instance: &mut ComponentInstance, + ty: TypeStreamTableIndex, + reader: u32, + ) -> Result<()> { + futures_and_streams::guest_close_readable( + StoreContextMut(self), + instance, + TableIndex::Stream(ty), + reader, + ) + } + + fn flat_stream_write( + &mut self, + instance: &mut ComponentInstance, + memory: *mut VMMemoryDefinition, + realloc: *mut VMFuncRef, + ty: TypeStreamTableIndex, + payload_size: u32, + payload_align: u32, + stream: u32, + address: u32, + count: u32, + ) -> Result { + futures_and_streams::guest_write( + StoreContextMut(self), + instance, + memory, + realloc, + StringEncoding::Utf8 as u8, + TableIndex::Stream(ty), + Some(FlatAbi { + size: payload_size, + align: payload_align, + }), + stream, + address, + count, + ) + } + + fn flat_stream_read( + &mut self, + instance: &mut ComponentInstance, + memory: *mut VMMemoryDefinition, + realloc: *mut VMFuncRef, + ty: TypeStreamTableIndex, + payload_size: u32, + payload_align: u32, + stream: u32, + address: u32, + count: u32, + ) -> Result { + futures_and_streams::guest_read( + StoreContextMut(self), + instance, + memory, + realloc, + StringEncoding::Utf8 as u8, + TableIndex::Stream(ty), + Some(FlatAbi { + size: payload_size, + align: payload_align, + }), + stream, + address, + count, + ) + } + + fn error_context_new( + &mut self, + instance: &mut ComponentInstance, + memory: *mut VMMemoryDefinition, + realloc: *mut VMFuncRef, + string_encoding: u8, + ty: TypeComponentLocalErrorContextTableIndex, + debug_msg_address: u32, + debug_msg_len: u32, + ) -> Result { + futures_and_streams::error_context_new( + StoreContextMut(self), + instance, + memory, + realloc, + string_encoding, + ty, + debug_msg_address, + debug_msg_len, + ) + } + + fn error_context_debug_message( + &mut self, + instance: &mut ComponentInstance, + memory: *mut VMMemoryDefinition, + realloc: *mut VMFuncRef, + string_encoding: u8, + ty: TypeComponentLocalErrorContextTableIndex, + err_ctx_handle: u32, + debug_msg_address: u32, + ) -> Result<()> { + futures_and_streams::error_context_debug_message( + StoreContextMut(self), + instance, + memory, + realloc, + string_encoding, + ty, + err_ctx_handle, + debug_msg_address, + ) + } + + fn error_context_drop( + &mut self, + instance: &mut ComponentInstance, + ty: TypeComponentLocalErrorContextTableIndex, + err_ctx_handle: u32, + ) -> Result<()> { + futures_and_streams::error_context_drop(StoreContextMut(self), instance, ty, err_ctx_handle) + } +} + +struct HostTaskResult { + event: Event, + param: u32, + caller: TableId, +} + +type HostTaskFuture = Pin< + Box< + dyn Future< + Output = ( + u32, + Box Result + Send + Sync>, + ), + > + Send + + Sync + + 'static, + >, +>; + +struct HostTask { + caller_instance: RuntimeComponentInstanceIndex, +} + +enum Deferred { + None, + Stackful { + fiber: StoreFiber<'static>, + async_: bool, + }, + Stackless { + call: Box Result + Send + Sync + 'static>, + instance: RuntimeComponentInstanceIndex, + callback: SendSyncPtr, + }, +} + +impl Deferred { + fn take_stackful(&mut self) -> Option<(StoreFiber<'static>, bool)> { + if let Self::Stackful { .. } = self { + let Self::Stackful { fiber, async_ } = mem::replace(self, Self::None) else { + unreachable!() + }; + Some((fiber, async_)) + } else { + None + } + } +} + +#[derive(Copy, Clone)] +struct Callback { + function: SendSyncPtr, + context: u32, + instance: RuntimeComponentInstanceIndex, +} + +enum Caller { + Host(Option>), + Guest { + task: TableId, + instance: RuntimeComponentInstanceIndex, + }, +} + +struct GuestTask { + lower_params: Option, + lift_result: Option<(RawLift, TypeTupleIndex)>, + result: Option, + callback: Option, + events: VecDeque<(Event, AnyTask, u32)>, + caller: Caller, + deferred: Deferred, + should_yield: bool, + call_context: Option, +} + +impl Default for GuestTask { + fn default() -> Self { + Self { + lower_params: None, + lift_result: None, + result: None, + callback: None, + events: VecDeque::new(), + caller: Caller::Host(None), + deferred: Deferred::None, + should_yield: false, + call_context: Some(CallContext::default()), + } + } +} + +#[derive(Copy, Clone)] +enum AnyTask { + Host(TableId), + Guest(TableId), + Transmit(TableId), +} + +impl AnyTask { + fn rep(&self) -> u32 { + match self { + Self::Host(task) => task.rep(), + Self::Guest(task) => task.rep(), + Self::Transmit(task) => task.rep(), + } + } + + fn delete_all_from(&self, mut store: StoreContextMut) -> Result<()> { + match self { + Self::Host(task) => { + log::trace!("delete host task {}", task.rep()); + store.concurrent_state().table.delete(*task).map(drop) + } + Self::Guest(task) => { + let finished = store + .concurrent_state() + .table + .get(*task)? + .events + .iter() + .filter_map(|(event, call, _)| (*event == Event::Done).then_some(*call)) + .collect::>(); + + for call in finished { + log::trace!("will delete call {}", call.rep()); + call.delete_all_from(store.as_context_mut())?; + } + + log::trace!("delete guest task {}", task.rep()); + store.concurrent_state().table.delete(*task).map(drop) + } + Self::Transmit(task) => store.concurrent_state().table.delete(*task).map(drop), + }?; + + Ok(()) + } +} + +pub(crate) struct LiftLowerContext { + pub(crate) pointer: *mut u8, + pub(crate) dropper: fn(*mut u8), +} + +unsafe impl Send for LiftLowerContext {} +unsafe impl Sync for LiftLowerContext {} + +impl Drop for LiftLowerContext { + fn drop(&mut self) { + (self.dropper)(self.pointer); + } +} + +type RawLower = + Box]) -> Result<()> + Send + Sync>; + +type LowerFn = fn(LiftLowerContext, *mut dyn VMStore, &mut [MaybeUninit]) -> Result<()>; + +type RawLift = Box< + dyn FnOnce(*mut dyn VMStore, &[ValRaw]) -> Result> + Send + Sync, +>; + +type LiftFn = + fn(LiftLowerContext, *mut dyn VMStore, &[ValRaw]) -> Result>; + +type LiftedResult = Box; + +struct DummyResult; + +struct Reset(*mut T, T); + +impl Drop for Reset { + fn drop(&mut self) { + unsafe { + *self.0 = self.1; + } + } +} + +#[derive(Clone, Copy)] +struct PollContext { + future_context: *mut Context<'static>, + guard_range_start: *mut u8, + guard_range_end: *mut u8, +} + +impl Default for PollContext { + fn default() -> PollContext { + PollContext { + future_context: ptr::null_mut(), + guard_range_start: ptr::null_mut(), + guard_range_end: ptr::null_mut(), + } + } +} + +struct AsyncState { + current_suspend: UnsafeCell< + *mut Suspend< + (Option<*mut dyn VMStore>, Result<()>), + Option<*mut dyn VMStore>, + (Option<*mut dyn VMStore>, Result<()>), + >, + >, + current_poll_cx: UnsafeCell, +} + +unsafe impl Send for AsyncState {} +unsafe impl Sync for AsyncState {} + +pub(crate) struct AsyncCx { + current_suspend: *mut *mut wasmtime_fiber::Suspend< + (Option<*mut dyn VMStore>, Result<()>), + Option<*mut dyn VMStore>, + (Option<*mut dyn VMStore>, Result<()>), + >, + current_stack_limit: *mut usize, + current_poll_cx: *mut PollContext, + track_pkey_context_switch: bool, +} + +impl AsyncCx { + pub(crate) fn new(store: &mut StoreContextMut) -> Self { + Self::try_new(store).unwrap() + } + + pub(crate) fn try_new(store: &mut StoreContextMut) -> Option { + let current_poll_cx = store.concurrent_state().async_state.current_poll_cx.get(); + if unsafe { (*current_poll_cx).future_context.is_null() } { + None + } else { + Some(Self { + current_suspend: store.concurrent_state().async_state.current_suspend.get(), + current_stack_limit: store.0.runtime_limits().stack_limit.get(), + current_poll_cx, + track_pkey_context_switch: store.has_pkey(), + }) + } + } + + unsafe fn poll(&self, mut future: Pin<&mut (dyn Future + Send)>) -> Poll { + let poll_cx = *self.current_poll_cx; + let _reset = Reset(self.current_poll_cx, poll_cx); + *self.current_poll_cx = PollContext::default(); + assert!(!poll_cx.future_context.is_null()); + future.as_mut().poll(&mut *poll_cx.future_context) + } + + pub(crate) unsafe fn block_on<'a, T, U>( + &self, + mut future: Pin<&mut (dyn Future + Send)>, + mut store: Option>, + ) -> Result<(U, Option>)> { + loop { + match self.poll(future.as_mut()) { + Poll::Ready(v) => break Ok((v, store)), + Poll::Pending => {} + } + + store = self.suspend(store)?; + } + } + + unsafe fn suspend<'a, T>( + &self, + store: Option>, + ) -> Result>> { + let previous_mask = if self.track_pkey_context_switch { + let previous_mask = mpk::current_mask(); + mpk::allow(ProtectionMask::all()); + previous_mask + } else { + ProtectionMask::all() + }; + let store = suspend_fiber(self.current_suspend, self.current_stack_limit, store); + if self.track_pkey_context_switch { + mpk::allow(previous_mask); + } + store + } +} + +#[derive(Default)] +struct InstanceState { + backpressure: bool, + in_sync_call: bool, + task_queue: VecDeque>, +} + +pub struct ConcurrentState { + guest_task: Option>, + futures: ReadyChunks>, + table: Table, + async_state: AsyncState, + // TODO: this can and should be a `PrimaryMap` + instance_states: HashMap, + yielding: HashSet, + unblocked: HashSet, + component_instance: Option>, + _phantom: PhantomData, +} + +impl ConcurrentState { + pub(crate) fn async_guard_range(&self) -> Range<*mut u8> { + let context = unsafe { *self.async_state.current_poll_cx.get() }; + context.guard_range_start..context.guard_range_end + } +} + +impl Default for ConcurrentState { + fn default() -> Self { + Self { + guest_task: None, + table: Table::new(), + futures: ReadyChunks::new(FuturesUnordered::new(), 1024), + async_state: AsyncState { + current_suspend: UnsafeCell::new(ptr::null_mut()), + current_poll_cx: UnsafeCell::new(PollContext::default()), + }, + instance_states: HashMap::new(), + yielding: HashSet::new(), + unblocked: HashSet::new(), + component_instance: None, + _phantom: PhantomData, + } + } +} + +fn dummy_waker() -> Waker { + struct DummyWaker; + + impl Wake for DummyWaker { + fn wake(self: Arc) {} + } + + static WAKER: Lazy> = Lazy::new(|| Arc::new(DummyWaker)); + + WAKER.clone().into() +} + +/// Provide a hint to Rust type inferencer that we're returning a compatible +/// closure from a `LinkerInstance::func_wrap_concurrent` future. +pub fn for_any(fun: F) -> F +where + F: FnOnce(StoreContextMut) -> R + 'static, + R: 'static, +{ + fun +} + +fn for_any_lower< + F: FnOnce(*mut dyn VMStore, &mut [MaybeUninit]) -> Result<()> + Send + Sync, +>( + fun: F, +) -> F { + fun +} + +fn for_any_lift< + F: FnOnce(*mut dyn VMStore, &[ValRaw]) -> Result> + Send + Sync, +>( + fun: F, +) -> F { + fun +} + +pub(crate) fn first_poll( + instance: *mut ComponentInstance, + mut store: StoreContextMut, + future: impl Future) -> Result + Send + Sync + 'static> + + Send + + Sync + + 'static, + caller_instance: RuntimeComponentInstanceIndex, + lower: impl FnOnce(StoreContextMut, R) -> Result<()> + Send + Sync + 'static, +) -> Result> { + let caller = store.concurrent_state().guest_task.unwrap(); + let task = store + .concurrent_state() + .table + .push_child(HostTask { caller_instance }, caller)?; + log::trace!("new child of {}: {}", caller.rep(), task.rep()); + let mut future = Box::pin(future.map(move |fun| { + ( + task.rep(), + Box::new(move |store: *mut dyn VMStore| { + let mut store = unsafe { StoreContextMut(&mut *store.cast()) }; + let result = fun(store.as_context_mut())?; + lower(store, result)?; + Ok(HostTaskResult { + event: Event::Done, + param: 0u32, + caller, + }) + }) + as Box Result + Send + Sync>, + ) + })) as HostTaskFuture; + + Ok( + match future + .as_mut() + .poll(&mut Context::from_waker(&dummy_waker())) + { + Poll::Ready((_, fun)) => { + log::trace!("delete host task {} (already ready)", task.rep()); + store.concurrent_state().table.delete(task)?; + fun(store.0.traitobj().as_ptr())?; + None + } + Poll::Pending => { + store.concurrent_state().futures.get_mut().push(future); + Some( + unsafe { &mut *instance }.component_waitable_tables()[caller_instance] + .insert(task.rep(), WaitableState::Task)?, + ) + } + }, + ) +} + +pub(crate) fn poll_and_block<'a, T, R: Send + Sync + 'static>( + mut store: StoreContextMut<'a, T>, + future: impl Future) -> Result + Send + Sync + 'static> + + Send + + Sync + + 'static, + caller_instance: RuntimeComponentInstanceIndex, +) -> Result<(R, StoreContextMut<'a, T>)> { + let Some(caller) = store.concurrent_state().guest_task else { + return match pin!(future).poll(&mut Context::from_waker(&dummy_waker())) { + Poll::Ready(fun) => { + let result = fun(store.as_context_mut())?; + Ok((result, store)) + } + Poll::Pending => { + unreachable!() + } + }; + }; + let old_result = store + .concurrent_state() + .table + .get_mut(caller) + .with_context(|| format!("bad handle: {}", caller.rep()))? + .result + .take(); + let task = store + .concurrent_state() + .table + .push_child(HostTask { caller_instance }, caller)?; + log::trace!("new child of {}: {}", caller.rep(), task.rep()); + let mut future = Box::pin(future.map(move |fun| { + ( + task.rep(), + Box::new(move |store: *mut dyn VMStore| { + let mut store = unsafe { StoreContextMut(&mut *store.cast()) }; + let result = fun(store.as_context_mut())?; + store.concurrent_state().table.get_mut(caller)?.result = + Some(Box::new(result) as _); + Ok(HostTaskResult { + event: Event::Done, + param: 0u32, + caller, + }) + }) + as Box Result + Send + Sync>, + ) + })) as HostTaskFuture; + + Ok( + match unsafe { AsyncCx::new(&mut store).poll(future.as_mut()) } { + Poll::Ready((_, fun)) => { + log::trace!("delete host task {} (already ready)", task.rep()); + store.concurrent_state().table.delete(task)?; + let store = store.0.traitobj().as_ptr(); + fun(store)?; + let mut store = unsafe { StoreContextMut(&mut *store.cast()) }; + let result = *mem::replace( + &mut store.concurrent_state().table.get_mut(caller)?.result, + old_result, + ) + .unwrap() + .downcast() + .unwrap(); + (result, store) + } + Poll::Pending => { + store.concurrent_state().futures.get_mut().push(future); + loop { + if let Some(result) = store + .concurrent_state() + .table + .get_mut(caller)? + .result + .take() + { + store.concurrent_state().table.get_mut(caller)?.result = old_result; + break (*result.downcast().unwrap(), store); + } else { + let async_cx = AsyncCx::new(&mut store); + store = unsafe { async_cx.suspend(Some(store)) }?.unwrap(); + } + } + } + }, + ) +} + +pub(crate) async fn on_fiber<'a, R: Send + Sync + 'static, T: Send>( + mut store: StoreContextMut<'a, T>, + instance: Option, + func: impl FnOnce(&mut StoreContextMut) -> R + Send, +) -> Result<(R, StoreContextMut<'a, T>)> { + let result = Arc::new(Mutex::new(None)); + let mut fiber = make_fiber(&mut store, instance, { + let result = result.clone(); + move |mut store| { + *result.lock().unwrap() = Some(func(&mut store)); + Ok(()) + } + })?; + + let guard_range = fiber + .fiber + .as_ref() + .unwrap() + .stack() + .guard_range() + .map(|r| { + ( + NonNull::new(r.start).map(SendSyncPtr::new), + NonNull::new(r.end).map(SendSyncPtr::new), + ) + }) + .unwrap_or((None, None)); + + store = poll_fn(store, guard_range, move |_, mut store| { + match resume_fiber(&mut fiber, store.take(), Ok(())) { + Ok(Ok((store, result))) => Ok(result.map(|()| store)), + Ok(Err(s)) => Err(s), + Err(e) => Ok(Err(e)), + } + }) + .await?; + + let result = result.lock().unwrap().take().unwrap(); + Ok((result, store)) +} + +fn maybe_push_call_context( + store: &mut StoreContextMut, + guest_task: TableId, +) -> Result<()> { + let task = store.concurrent_state().table.get_mut(guest_task)?; + if task.lift_result.is_some() { + log::trace!("push call context for {}", guest_task.rep()); + let call_context = task.call_context.take().unwrap(); + store.0.component_resource_state().0.push(call_context); + } + Ok(()) +} + +fn maybe_pop_call_context( + store: &mut StoreContextMut, + guest_task: TableId, +) -> Result<()> { + if store + .concurrent_state() + .table + .get_mut(guest_task)? + .lift_result + .is_some() + { + log::trace!("pop call context for {}", guest_task.rep()); + let call_context = Some(store.0.component_resource_state().0.pop().unwrap()); + store + .concurrent_state() + .table + .get_mut(guest_task)? + .call_context = call_context; + } + Ok(()) +} + +fn maybe_send_event<'a, T>( + mut store: StoreContextMut<'a, T>, + guest_task: TableId, + event: Event, + call: AnyTask, + result: u32, +) -> Result> { + assert_ne!(guest_task.rep(), call.rep()); + if let Some(callback) = store.concurrent_state().table.get(guest_task)?.callback { + let old_task = store.concurrent_state().guest_task.replace(guest_task); + let Some((handle, _)) = unsafe { + &mut *store + .concurrent_state() + .component_instance + .unwrap() + .as_ptr() + } + .component_waitable_tables()[callback.instance] + .get_mut_by_rep(call.rep()) + else { + bail!("handle not found for waitable rep {}", call.rep()); + }; + log::trace!( + "use callback to deliver event {event:?} to {} for {} (handle {handle}): {:?} {}", + guest_task.rep(), + call.rep(), + callback.function, + callback.context ); - todo!() + + maybe_push_call_context(&mut store, guest_task)?; + + let mut flags = unsafe { + (*store + .concurrent_state() + .component_instance + .unwrap() + .as_ptr()) + .instance_flags(callback.instance) + }; + + let params = &mut [ + ValRaw::u32(callback.context), + ValRaw::u32(event as u32), + ValRaw::u32(handle), + ValRaw::u32(result), + ]; + unsafe { + flags.set_may_enter(false); + crate::Func::call_unchecked_raw( + &mut store, + callback.function.as_non_null(), + params.as_mut_slice().into(), + )?; + flags.set_may_enter(true); + } + + maybe_pop_call_context(&mut store, guest_task)?; + + let done = params[0].get_u32() != 0; + log::trace!("{} done? {done}", guest_task.rep()); + if done { + store.concurrent_state().table.get_mut(guest_task)?.callback = None; + + match &store.concurrent_state().table.get(guest_task)?.caller { + Caller::Guest { task, .. } => { + let task = *task; + store = + maybe_send_event(store, task, Event::Done, AnyTask::Guest(guest_task), 0)?; + } + Caller::Host(_) => { + log::trace!("maybe_send_event will delete {}", call.rep()); + AnyTask::Guest(guest_task).delete_all_from(store.as_context_mut())?; + } + } + } + store.concurrent_state().guest_task = old_task; + Ok(store) + } else { + store + .concurrent_state() + .table + .get_mut(guest_task)? + .events + .push_back((event, call, result)); + + let resumed = if event == Event::Done { + if let Some((fiber, async_)) = store + .concurrent_state() + .table + .get_mut(guest_task)? + .deferred + .take_stackful() + { + log::trace!( + "use fiber to deliver event {event:?} to {} for {}", + guest_task.rep(), + call.rep() + ); + let old_task = store.concurrent_state().guest_task.replace(guest_task); + store = resume_stackful(store, guest_task, fiber, async_)?; + store.concurrent_state().guest_task = old_task; + true + } else { + false + } + } else { + false + }; + + if !resumed { + log::trace!( + "queue event {event:?} to {} for {}", + guest_task.rep(), + call.rep() + ); + } + + Ok(store) + } +} + +fn resume_stackful<'a, T>( + mut store: StoreContextMut<'a, T>, + guest_task: TableId, + mut fiber: StoreFiber<'static>, + async_: bool, +) -> Result> { + maybe_push_call_context(&mut store, guest_task)?; + + match resume_fiber(&mut fiber, Some(store), Ok(()))? { + Ok((mut store, result)) => { + result?; + if async_ { + if store + .concurrent_state() + .table + .get(guest_task)? + .lift_result + .is_some() + { + return Err(anyhow!(crate::Trap::NoAsyncResult)); + } + } + if let Some(instance) = fiber.instance { + store = maybe_resume_next_task(store, instance)?; + for (event, call, _) in mem::take( + &mut store + .concurrent_state() + .table + .get_mut(guest_task) + .with_context(|| format!("bad handle: {}", guest_task.rep()))? + .events, + ) { + if event == Event::Done { + log::trace!("resume_stackful will delete call {}", call.rep()); + call.delete_all_from(store.as_context_mut())?; + } + } + match &store.concurrent_state().table.get(guest_task)?.caller { + Caller::Host(_) => { + log::trace!("resume_stackful will delete task {}", guest_task.rep()); + AnyTask::Guest(guest_task).delete_all_from(store.as_context_mut())?; + Ok(store) + } + Caller::Guest { task, .. } => { + let task = *task; + maybe_send_event(store, task, Event::Done, AnyTask::Guest(guest_task), 0) + } + } + } else { + Ok(store) + } + } + Err(new_store) => { + store = new_store.unwrap(); + maybe_pop_call_context(&mut store, guest_task)?; + store.concurrent_state().table.get_mut(guest_task)?.deferred = + Deferred::Stackful { fiber, async_ }; + Ok(store) + } + } +} + +fn resume_stackless<'a, T>( + mut store: StoreContextMut<'a, T>, + guest_task: TableId, + call: Box Result>, + instance: RuntimeComponentInstanceIndex, + callback: SendSyncPtr, +) -> Result> { + maybe_push_call_context(&mut store, guest_task)?; + + let store = store.0.traitobj().as_ptr(); + let guest_context = call(store)?; + let mut store = unsafe { StoreContextMut(&mut *store.cast()) }; + + maybe_pop_call_context(&mut store, guest_task)?; + + let task = store.concurrent_state().table.get_mut(guest_task)?; + let event = if task.lift_result.is_some() { + if guest_context == 0 { + return Err(anyhow!(crate::Trap::NoAsyncResult)); + } + Event::Started + } else if guest_context != 0 { + Event::Returned + } else { + Event::Done + }; + if guest_context != 0 { + log::trace!("set callback for {}", guest_task.rep()); + task.callback = Some(Callback { + function: callback, + instance, + context: guest_context, + }); + for (event, call, result) in mem::take(&mut task.events) { + store = maybe_send_event(store, guest_task, event, call, result)?; + } + } + store = maybe_resume_next_task(store, instance)?; + if let Caller::Guest { task, .. } = &store.concurrent_state().table.get(guest_task)?.caller { + let task = *task; + maybe_send_event(store, task, event, AnyTask::Guest(guest_task), 0) + } else { + Ok(store) + } +} + +fn poll_for_result<'a, T>(mut store: StoreContextMut<'a, T>) -> Result> { + let task = store.concurrent_state().guest_task; + poll_loop(store, move |store| { + task.map(|task| { + Ok::<_, anyhow::Error>(store.concurrent_state().table.get(task)?.result.is_none()) + }) + .unwrap_or(Ok(true)) + }) +} + +fn handle_ready<'a, T>( + mut store: StoreContextMut<'a, T>, + ready: Vec<( + u32, + Box Result + Send + Sync>, + )>, +) -> Result> { + for (task, fun) in ready { + let vm_store = store.0.traitobj().as_ptr(); + let result = fun(vm_store)?; + store = unsafe { StoreContextMut::(&mut *vm_store.cast()) }; + let task = match result.event { + Event::Done => AnyTask::Host(TableId::::new(task)), + Event::StreamRead | Event::FutureRead | Event::StreamWrite | Event::FutureWrite => { + AnyTask::Transmit(TableId::::new(task)) + } + _ => unreachable!(), + }; + store = maybe_send_event(store, result.caller, result.event, task, result.param)?; + } + Ok(store) +} + +fn maybe_yield<'a, T>(mut store: StoreContextMut<'a, T>) -> Result> { + let guest_task = store.concurrent_state().guest_task.unwrap(); + + if store.concurrent_state().table.get(guest_task)?.should_yield { + log::trace!("maybe_yield suspend {}", guest_task.rep()); + + store.concurrent_state().yielding.insert(guest_task.rep()); + let cx = AsyncCx::new(&mut store); + store = unsafe { cx.suspend(Some(store)) }?.unwrap(); + + log::trace!("maybe_yield resume {}", guest_task.rep()); + } else { + log::trace!("maybe_yield skip {}", guest_task.rep()); + } + + Ok(store) +} + +fn unyield<'a, T>(mut store: StoreContextMut<'a, T>) -> Result<(StoreContextMut<'a, T>, bool)> { + let mut resumed = false; + for task in mem::take(&mut store.concurrent_state().yielding) { + let guest_task = TableId::::new(task); + if let Some((fiber, async_)) = store + .concurrent_state() + .table + .get_mut(guest_task)? + .deferred + .take_stackful() + { + resumed = true; + let old_task = store.concurrent_state().guest_task.replace(guest_task); + store = resume_stackful(store, guest_task, fiber, async_)?; + store.concurrent_state().guest_task = old_task; + } + } + + for instance in mem::take(&mut store.concurrent_state().unblocked) { + let entry = store + .concurrent_state() + .instance_states + .entry(instance) + .or_default(); + + if !(entry.backpressure || entry.in_sync_call) { + if let Some(task) = entry.task_queue.pop_front() { + resumed = true; + store = resume(store, task)?; + } + } + } + + Ok((store, resumed)) +} + +fn poll_loop<'a, T>( + mut store: StoreContextMut<'a, T>, + mut continue_: impl FnMut(&mut StoreContextMut<'a, T>) -> Result, +) -> Result> { + loop { + let cx = AsyncCx::new(&mut store); + let mut future = pin!(store.concurrent_state().futures.next()); + let ready = unsafe { cx.poll(future.as_mut()) }; + + match ready { + Poll::Ready(Some(ready)) => { + store = handle_ready(store, ready)?; + } + Poll::Ready(None) => { + let (s, resumed) = unyield(store)?; + store = s; + if !resumed { + log::trace!("exhausted future queue; exiting poll_loop"); + break; + } + } + Poll::Pending => { + let (s, resumed) = unyield(store)?; + store = s; + if continue_(&mut store)? { + let cx = AsyncCx::new(&mut store); + store = unsafe { cx.suspend(Some(store)) }?.unwrap(); + } else if !resumed { + break; + } + } + } + } + + Ok(store) +} + +fn resume<'a, T>( + mut store: StoreContextMut<'a, T>, + task: TableId, +) -> Result> { + log::trace!("resume {}", task.rep()); + + // TODO: Avoid calling `resume_stackful` or `resume_stackless` here, because it may call us, leading to + // recursion limited only by the number of waiters. Flatten this into an iteration instead. + let old_task = store.concurrent_state().guest_task.replace(task); + store = match mem::replace( + &mut store.concurrent_state().table.get_mut(task)?.deferred, + Deferred::None, + ) { + Deferred::None => unreachable!(), + Deferred::Stackful { fiber, async_ } => resume_stackful(store, task, fiber, async_), + Deferred::Stackless { + call, + instance, + callback, + } => resume_stackless(store, task, call, instance, callback), + }?; + store.concurrent_state().guest_task = old_task; + Ok(store) +} + +fn maybe_resume_next_task<'a, T>( + mut store: StoreContextMut<'a, T>, + instance: RuntimeComponentInstanceIndex, +) -> Result> { + let state = store + .concurrent_state() + .instance_states + .get_mut(&instance) + .unwrap(); + + if state.backpressure || state.in_sync_call { + Ok(store) + } else { + if let Some(next) = state.task_queue.pop_front() { + resume(store, next) + } else { + Ok(store) + } } } + +struct StoreFiber<'a> { + fiber: Option< + Fiber< + 'a, + (Option<*mut dyn VMStore>, Result<()>), + Option<*mut dyn VMStore>, + (Option<*mut dyn VMStore>, Result<()>), + >, + >, + state: Option, + engine: Engine, + suspend: *mut *mut Suspend< + (Option<*mut dyn VMStore>, Result<()>), + Option<*mut dyn VMStore>, + (Option<*mut dyn VMStore>, Result<()>), + >, + stack_limit: *mut usize, + instance: Option, +} + +impl<'a> Drop for StoreFiber<'a> { + fn drop(&mut self) { + if !self.fiber.as_ref().unwrap().done() { + let result = unsafe { resume_fiber_raw(self, None, Err(anyhow!("future dropped"))) }; + debug_assert!(result.is_ok()); + } + + self.state.take().unwrap().assert_null(); + + unsafe { + self.engine + .allocator() + .deallocate_fiber_stack(self.fiber.take().unwrap().into_stack()); + } + } +} + +unsafe impl<'a> Send for StoreFiber<'a> {} +unsafe impl<'a> Sync for StoreFiber<'a> {} + +fn make_fiber<'a, T>( + store: &mut StoreContextMut, + instance: Option, + fun: impl FnOnce(StoreContextMut) -> Result<()> + 'a, +) -> Result> { + let engine = store.engine().clone(); + let stack = engine.allocator().allocate_fiber_stack()?; + Ok(StoreFiber { + fiber: Some(Fiber::new( + stack, + move |(store_ptr, result): (Option<*mut dyn VMStore>, Result<()>), suspend| { + if result.is_err() { + (store_ptr, result) + } else { + unsafe { + let store_ptr = store_ptr.unwrap(); + let mut store = StoreContextMut(&mut *store_ptr.cast()); + let suspend_ptr = + store.concurrent_state().async_state.current_suspend.get(); + let _reset = Reset(suspend_ptr, *suspend_ptr); + *suspend_ptr = suspend; + (Some(store_ptr), fun(store.as_context_mut())) + } + } + }, + )?), + state: Some(AsyncWasmCallState::new()), + engine, + suspend: store.concurrent_state().async_state.current_suspend.get(), + stack_limit: store.0.runtime_limits().stack_limit.get(), + instance, + }) +} + +unsafe fn resume_fiber_raw<'a>( + fiber: *mut StoreFiber<'a>, + store: Option<*mut dyn VMStore>, + result: Result<()>, +) -> Result<(Option<*mut dyn VMStore>, Result<()>), Option<*mut dyn VMStore>> { + struct Restore<'a> { + fiber: *mut StoreFiber<'a>, + state: Option, + } + + impl Drop for Restore<'_> { + fn drop(&mut self) { + unsafe { + (*self.fiber).state = Some(self.state.take().unwrap().restore()); + } + } + } + + let _reset_suspend = Reset((*fiber).suspend, *(*fiber).suspend); + let _reset_stack_limit = Reset((*fiber).stack_limit, *(*fiber).stack_limit); + let state = Some((*fiber).state.take().unwrap().push()); + let restore = Restore { fiber, state }; + (*restore.fiber) + .fiber + .as_ref() + .unwrap() + .resume((store, result)) +} + +fn poll_ready<'a, T>(mut store: StoreContextMut<'a, T>) -> Result> { + unsafe { + let cx = *store.concurrent_state().async_state.current_poll_cx.get(); + assert!(!cx.future_context.is_null()); + while let Poll::Ready(Some(ready)) = store + .concurrent_state() + .futures + .poll_next_unpin(&mut *cx.future_context) + { + match handle_ready(store, ready) { + Ok(s) => { + store = s; + } + Err(e) => { + return Err(e); + } + } + } + } + Ok(store) +} + +fn resume_fiber<'a, T>( + fiber: &mut StoreFiber, + mut store: Option>, + result: Result<()>, +) -> Result, Result<()>), Option>>> { + if let Some(s) = store.take() { + store = Some(poll_ready(s)?); + } + + unsafe { + match resume_fiber_raw(fiber, store.map(|s| s.0.traitobj().as_ptr()), result) + .map(|(store, result)| (StoreContextMut(&mut *store.unwrap().cast()), result)) + .map_err(|v| v.map(|v| StoreContextMut(&mut *v.cast()))) + { + Ok(pair) => Ok(Ok(pair)), + Err(s) => { + if let Some(range) = fiber.fiber.as_ref().unwrap().stack().range() { + AsyncWasmCallState::assert_current_state_not_in_range(range); + } + + Ok(Err(s)) + } + } + } +} + +unsafe fn suspend_fiber<'a, T>( + suspend: *mut *mut Suspend< + (Option<*mut dyn VMStore>, Result<()>), + Option<*mut dyn VMStore>, + (Option<*mut dyn VMStore>, Result<()>), + >, + stack_limit: *mut usize, + store: Option>, +) -> Result>> { + let _reset_suspend = Reset(suspend, *suspend); + let _reset_stack_limit = Reset(stack_limit, *stack_limit); + let (store, result) = (**suspend).suspend(store.map(|s| s.0.traitobj().as_ptr())); + result?; + Ok(store.map(|v| StoreContextMut(&mut *v.cast()))) +} + +enum TaskCheck { + Wait(*mut VMMemoryDefinition, u32, RuntimeComponentInstanceIndex), + Poll(*mut VMMemoryDefinition, u32, RuntimeComponentInstanceIndex), + Yield, +} + +fn task_check( + mut cx: StoreContextMut, + instance: *mut ComponentInstance, + async_: bool, + check: TaskCheck, +) -> Result { + if async_ { + bail!("todo: async `task.wait`, `task.poll`, and `task.yield` not yet implemented"); + } + + let guest_task = cx.concurrent_state().guest_task.unwrap(); + + log::trace!("task check for {}", guest_task.rep()); + + let wait = matches!(check, TaskCheck::Wait(..)); + + if wait + && cx + .concurrent_state() + .table + .get(guest_task)? + .callback + .is_some() + { + bail!("cannot call `task.wait` from async-lifted export with callback"); + } + + if matches!(check, TaskCheck::Yield) + || cx + .concurrent_state() + .table + .get(guest_task)? + .events + .is_empty() + { + cx = maybe_yield(cx)?; + + if cx + .concurrent_state() + .table + .get(guest_task)? + .events + .is_empty() + { + cx = poll_loop(cx, move |cx| { + Ok::<_, anyhow::Error>( + wait && cx + .concurrent_state() + .table + .get(guest_task)? + .events + .is_empty(), + ) + })?; + } + } + + log::trace!("task check for {}, part two", guest_task.rep()); + + let result = match check { + TaskCheck::Wait(memory, payload, caller_instance) => { + let (event, call, result) = cx + .concurrent_state() + .table + .get_mut(guest_task)? + .events + .pop_front() + .ok_or_else(|| anyhow!("no tasks to wait for"))?; + + log::trace!( + "deliver event {event:?} via task.wait to {} for {}", + guest_task.rep(), + call.rep() + ); + + let entry = unsafe { + (*instance).component_waitable_tables()[caller_instance].get_mut_by_rep(call.rep()) + }; + let Some((handle, _)) = entry else { + bail!("handle not found for waitable rep {}", call.rep()); + }; + + let options = unsafe { + Options::new( + cx.0.id(), + NonNull::new(memory), + None, + StringEncoding::Utf8, + true, + None, + ) + }; + let types = unsafe { (*instance).component_types() }; + let ptr = + func::validate_inbounds::(options.memory_mut(cx.0), &ValRaw::u32(payload))?; + let mut lower = unsafe { LowerContext::new(cx, &options, types, instance) }; + handle.store(&mut lower, InterfaceType::U32, ptr)?; + result.store(&mut lower, InterfaceType::U32, ptr + 4)?; + + Ok(event as u32) + } + TaskCheck::Poll(memory, payload, caller_instance) => { + if let Some((event, call, result)) = cx + .concurrent_state() + .table + .get_mut(guest_task)? + .events + .pop_front() + { + let entry = unsafe { + (*instance).component_waitable_tables()[caller_instance] + .get_mut_by_rep(call.rep()) + }; + let Some((handle, _)) = entry else { + bail!("handle not found for waitable rep {}", call.rep()); + }; + + let options = unsafe { + Options::new( + cx.0.id(), + NonNull::new(memory), + None, + StringEncoding::Utf8, + true, + None, + ) + }; + let types = unsafe { (*instance).component_types() }; + let ptr = func::validate_inbounds::<(u32, u32)>( + options.memory_mut(cx.0), + &ValRaw::u32(payload), + )?; + let mut lower = unsafe { LowerContext::new(cx, &options, types, instance) }; + (event as u32).store(&mut lower, InterfaceType::U32, ptr)?; + handle.store(&mut lower, InterfaceType::U32, ptr + 4)?; + result.store(&mut lower, InterfaceType::U32, ptr + 8)?; + + Ok(1) + } else { + log::trace!( + "no events ready to deliver via task.poll to {}", + guest_task.rep() + ); + + Ok(0) + } + } + TaskCheck::Yield => Ok(0), + }; + + result +} + +fn may_enter( + store: &mut StoreContextMut, + mut guest_task: TableId, + guest_instance: RuntimeComponentInstanceIndex, +) -> bool { + // Walk the task tree back to the root, looking for potential reentrance. + // + // TODO: This could be optimized by maintaining a per-`GuestTask` bitset + // such that each bit represents and instance which has been entered by that + // task or an ancestor of that task, in which case this would be a constant + // time check. + loop { + match &store + .concurrent_state() + .table + .get_mut(guest_task) + .unwrap() + .caller + { + Caller::Host(_) => break true, + Caller::Guest { task, instance } => { + if *instance == guest_instance { + break false; + } else { + guest_task = *task; + } + } + } + } +} + +fn make_call( + guest_task: TableId, + callee: SendSyncPtr, + callee_instance: RuntimeComponentInstanceIndex, + param_count: usize, + result_count: usize, + flags: Option, +) -> impl FnOnce( + StoreContextMut, +) -> Result<([MaybeUninit; MAX_FLAT_PARAMS], StoreContextMut)> + + Send + + Sync + + 'static { + move |mut cx: StoreContextMut| { + if !may_enter(&mut cx, guest_task, callee_instance) { + bail!(crate::Trap::CannotEnterComponent); + } + + let mut storage = [MaybeUninit::uninit(); MAX_FLAT_PARAMS]; + let lower = cx + .concurrent_state() + .table + .get_mut(guest_task)? + .lower_params + .take() + .unwrap(); + let cx = cx.0.traitobj().as_ptr(); + lower(cx, &mut storage[..param_count])?; + let mut cx = unsafe { StoreContextMut::(&mut *cx.cast()) }; + + unsafe { + if let Some(mut flags) = flags { + flags.set_may_enter(false); + } + crate::Func::call_unchecked_raw( + &mut cx, + callee.as_non_null(), + NonNull::new( + &mut storage[..param_count.max(result_count)] as *mut [MaybeUninit] + as _, + ) + .unwrap(), + )?; + if let Some(mut flags) = flags { + flags.set_may_enter(true); + } + } + + Ok((storage, cx)) + } +} + +fn do_start_call<'a, T>( + mut cx: StoreContextMut<'a, T>, + instance: *mut ComponentInstance, + guest_task: TableId, + async_: bool, + call: impl FnOnce( + StoreContextMut, + ) -> Result<([MaybeUninit; MAX_FLAT_PARAMS], StoreContextMut)> + + Send + + Sync + + 'static, + callback: Option>, + post_return: Option>, + callee_instance: RuntimeComponentInstanceIndex, + result_count: usize, +) -> Result<(u32, StoreContextMut<'a, T>)> { + let state = &mut cx + .concurrent_state() + .instance_states + .entry(callee_instance) + .or_default(); + let ready = state.task_queue.is_empty() && !(state.backpressure || state.in_sync_call); + + let mut guest_context = 0; + let mut async_finished = false; + + let mut cx = if let Some(callback) = callback { + assert!(async_); + + if ready { + maybe_push_call_context(&mut cx, guest_task)?; + let (storage, mut cx) = call(cx)?; + guest_context = unsafe { storage[0].assume_init() }.get_i32() as u32; + async_finished = guest_context == 0; + maybe_pop_call_context(&mut cx, guest_task)?; + cx + } else { + cx.concurrent_state() + .instance_states + .get_mut(&callee_instance) + .unwrap() + .task_queue + .push_back(guest_task); + + cx.concurrent_state().table.get_mut(guest_task)?.deferred = Deferred::Stackless { + call: Box::new(move |cx| { + let mut cx = unsafe { StoreContextMut(&mut *cx.cast()) }; + let old_task = cx.concurrent_state().guest_task.replace(guest_task); + let (storage, mut cx) = call(cx)?; + cx.concurrent_state().guest_task = old_task; + Ok(unsafe { storage[0].assume_init() }.get_i32() as u32) + }), + instance: callee_instance, + callback, + }; + cx + } + } else { + let mut fiber = make_fiber(&mut cx, Some(callee_instance), move |mut cx| { + let mut flags = unsafe { (*instance).instance_flags(callee_instance) }; + + if !async_ { + cx.concurrent_state() + .instance_states + .get_mut(&callee_instance) + .unwrap() + .in_sync_call = true; + } + + let (storage, mut cx) = call(cx)?; + + if !async_ { + cx.concurrent_state() + .instance_states + .get_mut(&callee_instance) + .unwrap() + .in_sync_call = false; + + let (lift, _) = cx + .concurrent_state() + .table + .get_mut(guest_task)? + .lift_result + .take() + .unwrap(); + + assert!(cx + .concurrent_state() + .table + .get(guest_task)? + .result + .is_none()); + + let cx = cx.0.traitobj().as_ptr(); + let result = lift(cx, unsafe { + mem::transmute::<&[MaybeUninit], &[ValRaw]>(&storage[..result_count]) + })?; + let mut cx = unsafe { StoreContextMut::(&mut *cx.cast()) }; + + unsafe { flags.set_needs_post_return(false) } + + if let Some(func) = post_return { + let arg = match result_count { + 0 => ValRaw::i32(0), + 1 => unsafe { storage[0].assume_init() }, + _ => unreachable!(), + }; + unsafe { + crate::Func::call_unchecked_raw( + &mut cx, + func.as_non_null(), + NonNull::new(ptr::slice_from_raw_parts(&arg, 1).cast_mut()).unwrap(), + )?; + } + } + + unsafe { flags.set_may_enter(true) } + + let (calls, host_table, _) = cx.0.component_resource_state(); + ResourceTables { + calls, + host_table: Some(host_table), + tables: unsafe { Some((*instance).component_resource_tables()) }, + } + .exit_call()?; + + if let Caller::Host(tx) = + &mut cx.concurrent_state().table.get_mut(guest_task)?.caller + { + _ = tx.take().unwrap().send(result); + } else { + cx.concurrent_state().table.get_mut(guest_task)?.result = Some(result); + } + } + + Ok(()) + })?; + + cx.concurrent_state() + .table + .get_mut(guest_task)? + .should_yield = true; + + if ready { + maybe_push_call_context(&mut cx, guest_task)?; + let mut cx = Some(cx); + loop { + match resume_fiber(&mut fiber, cx.take(), Ok(()))? { + Ok((cx, result)) => { + async_finished = async_; + result?; + break maybe_resume_next_task(cx, callee_instance)?; + } + Err(cx) => { + if let Some(mut cx) = cx { + maybe_pop_call_context(&mut cx, guest_task)?; + cx.concurrent_state().table.get_mut(guest_task)?.deferred = + Deferred::Stackful { fiber, async_ }; + break cx; + } else { + unsafe { suspend_fiber::(fiber.suspend, fiber.stack_limit, None)? }; + } + } + } + } + } else { + cx.concurrent_state() + .instance_states + .get_mut(&callee_instance) + .unwrap() + .task_queue + .push_back(guest_task); + + cx.concurrent_state().table.get_mut(guest_task)?.deferred = + Deferred::Stackful { fiber, async_ }; + cx + } + }; + + let guest_task = cx.concurrent_state().guest_task.take().unwrap(); + + let caller = + if let Caller::Guest { task, .. } = &cx.concurrent_state().table.get(guest_task)?.caller { + Some(*task) + } else { + None + }; + cx.concurrent_state().guest_task = caller; + + let task = cx.concurrent_state().table.get_mut(guest_task)?; + + if guest_context != 0 { + log::trace!("set callback for {}", guest_task.rep()); + task.callback = Some(Callback { + function: callback.unwrap(), + instance: callee_instance, + context: guest_context, + }); + for (event, call, result) in mem::take(&mut task.events) { + cx = maybe_send_event(cx, guest_task, event, call, result)?; + } + } else if async_finished + && !(matches!(&task.caller, Caller::Guest {..} if task.result.is_some()) + || matches!(&task.caller, Caller::Host(tx) if tx.is_none())) + { + return Err(anyhow!(crate::Trap::NoAsyncResult)); + } + + Ok((guest_context, cx)) +} + +pub(crate) fn start_call<'a, T: Send, LowerParams: Copy, R: 'static>( + mut store: StoreContextMut<'a, T>, + lower_params: LowerFn, + lower_context: LiftLowerContext, + lift_result: LiftFn, + lift_context: LiftLowerContext, + handle: Func, +) -> Result<(Promise, StoreContextMut<'a, T>)> { + // TODO: Check to see if the callee is using the memory64 ABI, in which case we must use task_return_type64. + // How do we check that? + let func_data = &store.0[handle.0]; + let task_return_type = func_data.types[func_data.ty].results; + let is_concurrent = func_data.options.async_(); + let component_instance = func_data.component_instance; + let instance = func_data.instance; + let callee = func_data.export.func_ref; + let callback = func_data.options.callback; + let post_return = func_data.post_return; + + assert!(store.concurrent_state().guest_task.is_none()); + + // TODO: Can we safely leave this set? Can the same store be used with more than one ComponentInstance? Could + // we instead set this when the ConcurrentState is created so we don't have to set/unset it on the fly? + store.concurrent_state().component_instance = + Some(store.0[instance.0].as_ref().unwrap().state.ptr); + + let (tx, rx) = oneshot::channel(); + + let guest_task = store.concurrent_state().table.push(GuestTask { + lower_params: Some(Box::new(for_any_lower(move |store, params| { + lower_params(lower_context, store, params) + })) as RawLower), + lift_result: Some(( + Box::new(for_any_lift(move |store, result| { + lift_result(lift_context, store, result) + })) as RawLift, + task_return_type, + )), + caller: Caller::Host(Some(tx)), + ..GuestTask::default() + })?; + + log::trace!("starting call {}", guest_task.rep()); + + let instance = store.0[instance.0].as_ref().unwrap().instance_ptr(); + + let call = make_call( + guest_task, + SendSyncPtr::new(callee), + component_instance, + mem::size_of::() / mem::size_of::(), + 1, + if callback.is_none() { + None + } else { + Some(unsafe { (*instance).instance_flags(component_instance) }) + }, + ); + + store.concurrent_state().guest_task = Some(guest_task); + + store = do_start_call( + store, + instance, + guest_task, + is_concurrent, + call, + callback.map(SendSyncPtr::new), + post_return.map(|f| SendSyncPtr::new(f.func_ref)), + component_instance, + 1, + )? + .1; + + store.concurrent_state().guest_task = None; + + log::trace!("started call {}", guest_task.rep()); + + Ok(( + Promise(Box::pin( + rx.map(|result| *result.unwrap().downcast().unwrap()), + )), + store, + )) +} + +pub(crate) fn call<'a, T: Send, LowerParams: Copy, R: 'static>( + store: StoreContextMut<'a, T>, + lower_params: LowerFn, + lower_context: LiftLowerContext, + lift_result: LiftFn, + lift_context: LiftLowerContext, + handle: Func, +) -> Result<(R, StoreContextMut<'a, T>)> { + let (promise, mut store) = start_call::<_, LowerParams, R>( + store, + lower_params, + lower_context, + lift_result, + lift_context, + handle, + )?; + + let mut future = promise.into_future(); + let result = Arc::new(Mutex::new(None)); + store = poll_loop(store, { + let result = result.clone(); + move |store| { + let cx = AsyncCx::new(store); + let ready = unsafe { cx.poll(future.as_mut()) }; + Ok(match ready { + Poll::Ready(value) => { + *result.lock().unwrap() = Some(value); + false + } + Poll::Pending => true, + }) + } + })?; + + let result = result.lock().unwrap().take(); + if let Some(result) = result { + Ok((result, store)) + } else { + // All outstanding host tasks completed, but the guest never yielded a result. + Err(anyhow!(crate::Trap::NoAsyncResult)) + } +} + +pub(crate) async fn poll_until<'a, T: Send, U>( + mut store: StoreContextMut<'a, T>, + future: impl Future, +) -> Result<(StoreContextMut<'a, T>, U)> { + let mut future = Box::pin(future); + loop { + loop { + let mut ready = pin!(store.concurrent_state().futures.next()); + + let mut ready = future::poll_fn({ + move |cx| { + Poll::Ready(match ready.as_mut().poll(cx) { + Poll::Ready(Some(value)) => Some(value), + Poll::Ready(None) | Poll::Pending => None, + }) + } + }) + .await; + + if ready.is_some() { + store = poll_fn(store, (None, None), move |_, mut store| { + Ok(handle_ready(store.take().unwrap(), ready.take().unwrap())) + }) + .await?; + } else { + let (s, resumed) = poll_fn(store, (None, None), move |_, mut store| { + Ok(unyield(store.take().unwrap())) + }) + .await?; + store = s; + if !resumed { + break; + } + } + } + + let ready = pin!(store.concurrent_state().futures.next()); + + match future::select(ready, future).await { + Either::Left((None, future_again)) => break Ok((store, future_again.await)), + Either::Left((Some(ready), future_again)) => { + let mut ready = Some(ready); + store = poll_fn(store, (None, None), move |_, mut store| { + Ok(handle_ready(store.take().unwrap(), ready.take().unwrap())) + }) + .await?; + future = future_again; + } + Either::Right((result, _)) => break Ok((store, result)), + } + } +} + +async fn poll_fn<'a, T, R>( + mut store: StoreContextMut<'a, T>, + guard_range: (Option>, Option>), + mut fun: impl FnMut( + &mut Context, + Option>, + ) -> Result>>, +) -> R { + #[derive(Clone, Copy)] + struct PollCx(*mut PollContext); + + unsafe impl Send for PollCx {} + + let poll_cx = PollCx(store.concurrent_state().async_state.current_poll_cx.get()); + future::poll_fn({ + let mut store = Some(store); + + move |cx| unsafe { + let _reset = Reset(poll_cx.0, *poll_cx.0); + let guard_range_start = guard_range.0.map(|v| v.as_ptr()).unwrap_or(ptr::null_mut()); + let guard_range_end = guard_range.1.map(|v| v.as_ptr()).unwrap_or(ptr::null_mut()); + *poll_cx.0 = PollContext { + future_context: mem::transmute::<&mut Context<'_>, *mut Context<'static>>(cx), + guard_range_start, + guard_range_end, + }; + #[allow(dropping_copy_types)] + drop(poll_cx); + + match fun(cx, store.take()) { + Ok(v) => Poll::Ready(v), + Err(s) => { + store = s; + Poll::Pending + } + } + } + }) + .await +} diff --git a/crates/wasmtime/src/runtime/component/concurrent/futures_and_streams.rs b/crates/wasmtime/src/runtime/component/concurrent/futures_and_streams.rs index af407ddc6d4e..e81e535dbebb 100644 --- a/crates/wasmtime/src/runtime/component/concurrent/futures_and_streams.rs +++ b/crates/wasmtime/src/runtime/component/concurrent/futures_and_streams.rs @@ -1,14 +1,1945 @@ -use std::marker::PhantomData; +use { + super::{table::TableId, Event, GuestTask, HostTaskFuture, HostTaskResult, Promise}, + crate::{ + component::{ + func::{self, Lift, LiftContext, LowerContext, Options}, + matching::InstanceType, + values::{ErrorContextAny, FutureAny, StreamAny}, + Lower, Val, WasmList, WasmStr, + }, + vm::{ + component::{ + ComponentInstance, ErrorContextState, GlobalErrorContextRefCount, + LocalErrorContextRefCount, StateTable, StreamFutureState, WaitableState, + }, + SendSyncPtr, VMFuncRef, VMMemoryDefinition, VMStore, + }, + AsContextMut, StoreContextMut, ValRaw, + }, + anyhow::{anyhow, bail, Context, Result}, + futures::{ + channel::oneshot, + future::{self, FutureExt}, + }, + std::{ + any::Any, + boxed::Box, + marker::PhantomData, + mem::{self, MaybeUninit}, + ptr::NonNull, + string::ToString, + sync::Arc, + vec::Vec, + }, + wasmtime_environ::component::{ + CanonicalAbiInfo, ComponentTypes, InterfaceType, StringEncoding, + TypeComponentGlobalErrorContextTableIndex, TypeComponentLocalErrorContextTableIndex, + TypeFutureTableIndex, TypeStreamTableIndex, + }, +}; + +const BLOCKED: usize = 0xffff_ffff; +const CLOSED: usize = 0x8000_0000; + +#[derive(Copy, Clone, Debug)] +pub(super) enum TableIndex { + Stream(TypeStreamTableIndex), + Future(TypeFutureTableIndex), +} + +fn payload(ty: TableIndex, types: &Arc) -> Option { + match ty { + TableIndex::Future(ty) => types[types[ty].ty].payload, + TableIndex::Stream(ty) => types[types[ty].ty].payload, + } +} + +fn state_table(instance: &mut ComponentInstance, ty: TableIndex) -> &mut StateTable { + let runtime_instance = match ty { + TableIndex::Stream(ty) => instance.component_types()[ty].instance, + TableIndex::Future(ty) => instance.component_types()[ty].instance, + }; + &mut instance.component_waitable_tables()[runtime_instance] +} + +fn push_event( + mut store: StoreContextMut, + rep: u32, + event: Event, + param: usize, + caller: TableId, +) { + store + .concurrent_state() + .futures + .get_mut() + .push(Box::pin(future::ready(( + rep, + Box::new(move |_| { + Ok(HostTaskResult { + event, + param: u32::try_from(param).unwrap(), + caller, + }) + }) + as Box Result + Send + Sync>, + ))) as HostTaskFuture); +} + +fn get_mut_by_index( + instance: &mut ComponentInstance, + ty: TableIndex, + index: u32, +) -> Result<(u32, &mut StreamFutureState)> { + get_mut_by_index_from(state_table(instance, ty), ty, index) +} + +fn get_mut_by_index_from( + state_table: &mut StateTable, + ty: TableIndex, + index: u32, +) -> Result<(u32, &mut StreamFutureState)> { + Ok(match ty { + TableIndex::Stream(ty) => { + let (rep, WaitableState::Stream(actual_ty, state)) = + state_table.get_mut_by_index(index)? + else { + bail!("invalid stream handle"); + }; + if *actual_ty != ty { + bail!("invalid stream handle"); + } + (rep, state) + } + TableIndex::Future(ty) => { + let (rep, WaitableState::Future(actual_ty, state)) = + state_table.get_mut_by_index(index)? + else { + bail!("invalid future handle"); + }; + if *actual_ty != ty { + bail!("invalid future handle"); + } + (rep, state) + } + }) +} + +fn waitable_state(ty: TableIndex, state: StreamFutureState) -> WaitableState { + match ty { + TableIndex::Stream(ty) => WaitableState::Stream(ty, state), + TableIndex::Future(ty) => WaitableState::Future(ty, state), + } +} + +fn accept( + values: Vec, + mut offset: usize, + transmit_id: TableId, + tx: oneshot::Sender<()>, +) -> impl FnOnce(Reader) -> Result + Send + Sync + 'static { + move |reader| { + let count = match reader { + Reader::Guest { + lower: + RawLowerContext { + store, + options, + types, + instance, + }, + ty, + address, + count, + } => { + let mut store = unsafe { StoreContextMut::(&mut *store.cast()) }; + let lower = &mut unsafe { + LowerContext::new(store.as_context_mut(), options, types, instance) + }; + if address % usize::try_from(T::ALIGN32)? != 0 { + bail!("read pointer not aligned"); + } + lower + .as_slice_mut() + .get_mut(address..) + .and_then(|b| b.get_mut(..T::SIZE32 * count)) + .ok_or_else(|| anyhow::anyhow!("read pointer out of bounds of memory"))?; + + let count = values.len().min(usize::try_from(count).unwrap()); + + if let Some(ty) = payload(ty, types) { + T::store_list(lower, ty, address, &values[offset..][..count])?; + } + offset += count; + + if offset < values.len() { + let transmit = store.concurrent_state().table.get_mut(transmit_id)?; + assert!(matches!(&transmit.write, WriteState::Open)); + + transmit.write = WriteState::HostReady { + accept: Box::new(accept::(values, offset, transmit_id, tx)), + close: false, + }; + } + + count + } + Reader::Host { accept } => { + assert!(offset == 0); // todo: do we need to handle offset != 0? + let count = values.len(); + accept(Box::new(values))?; + + count + } + Reader::None => 0, + }; + + Ok(count) + } +} + +fn host_write>( + mut store: S, + rep: u32, + values: Vec, + mut close: bool, +) -> Result> { + let mut store = store.as_context_mut(); + let (tx, rx) = oneshot::channel(); + let transmit_id = TableId::::new(rep); + let mut offset = 0; + + loop { + let transmit = store + .concurrent_state() + .table + .get_mut(transmit_id) + .with_context(|| rep.to_string())?; + let new_state = if let ReadState::Closed = &transmit.read { + ReadState::Closed + } else { + ReadState::Open + }; + + match mem::replace(&mut transmit.read, new_state) { + ReadState::Open => { + assert!(matches!(&transmit.write, WriteState::Open)); + + transmit.write = WriteState::HostReady { + accept: Box::new(accept::(values, offset, transmit_id, tx)), + close, + }; + close = false; + } + + ReadState::GuestReady { + ty, + flat_abi: _, + options, + address, + count, + instance, + handle, + caller, + } => unsafe { + let types = (*instance.as_ptr()).component_types(); + let lower = &mut LowerContext::new( + store.as_context_mut(), + &options, + types, + instance.as_ptr(), + ); + if address % usize::try_from(T::ALIGN32)? != 0 { + bail!("read pointer not aligned"); + } + lower + .as_slice_mut() + .get_mut(address..) + .and_then(|b| b.get_mut(..T::SIZE32 * count)) + .ok_or_else(|| anyhow::anyhow!("read pointer out of bounds of memory"))?; + + let count = values.len().min(count); + if let Some(ty) = payload(ty, types) { + T::store_list(lower, ty, address, &values[offset..][..count])?; + } + offset += count; + + log::trace!( + "remove read child of {}: {}", + caller.rep(), + transmit_id.rep() + ); + store + .concurrent_state() + .table + .remove_child(transmit_id, caller)?; + + *get_mut_by_index(&mut *instance.as_ptr(), ty, handle)?.1 = StreamFutureState::Read; + + push_event( + store.as_context_mut(), + transmit_id.rep(), + match ty { + TableIndex::Future(_) => Event::FutureRead, + TableIndex::Stream(_) => Event::StreamRead, + }, + count, + caller, + ); + + if offset < values.len() { + continue; + } + }, + + ReadState::HostReady { accept } => { + accept(Writer::Host { + values: Box::new(values), + })?; + } + + ReadState::Closed => {} + } + + if close { + host_close_writer(store, rep)?; + } + + break Ok(rx); + } +} + +pub fn host_read>( + mut store: S, + rep: u32, +) -> Result>>> { + let mut store = store.as_context_mut(); + let (tx, rx) = oneshot::channel(); + let transmit_id = TableId::::new(rep); + let transmit = store + .concurrent_state() + .table + .get_mut(transmit_id) + .with_context(|| rep.to_string())?; + let new_state = if let WriteState::Closed = &transmit.write { + WriteState::Closed + } else { + WriteState::Open + }; + + match mem::replace(&mut transmit.write, new_state) { + WriteState::Open => { + assert!(matches!(&transmit.read, ReadState::Open)); + + transmit.read = ReadState::HostReady { + accept: Box::new(move |writer| { + Ok(match writer { + Writer::Guest { + lift, + ty, + address, + count, + } => { + _ = tx.send( + ty.map(|ty| { + if address % usize::try_from(T::ALIGN32)? != 0 { + bail!("write pointer not aligned"); + } + lift.memory() + .get(address..) + .and_then(|b| b.get(..T::SIZE32 * count)) + .ok_or_else(|| { + anyhow::anyhow!("write pointer out of bounds of memory") + })?; + + let list = &WasmList::new(address, count, lift, ty)?; + T::load_list(lift, list) + }) + .transpose()?, + ); + count + } + Writer::Host { values } => { + let values = *values + .downcast::>() + .map_err(|_| anyhow!("transmit type mismatch"))?; + let count = values.len(); + _ = tx.send(Some(values)); + count + } + Writer::None => 0, + }) + }), + }; + } + + WriteState::GuestReady { + ty, + flat_abi: _, + options, + address, + count, + instance, + handle, + caller, + close, + } => unsafe { + let types = (*instance.as_ptr()).component_types(); + let lift = &mut LiftContext::new(store.0, &options, types, instance.as_ptr()); + _ = tx.send( + payload(ty, types) + .map(|ty| { + let list = &WasmList::new(address, count, lift, ty)?; + T::load_list(lift, list) + }) + .transpose()?, + ); + + log::trace!( + "remove write child of {}: {}", + caller.rep(), + transmit_id.rep() + ); + store + .concurrent_state() + .table + .remove_child(transmit_id, caller)?; + + if close { + store.concurrent_state().table.get_mut(transmit_id)?.write = WriteState::Closed; + } else { + *get_mut_by_index(&mut *instance.as_ptr(), ty, handle)?.1 = + StreamFutureState::Write; + } + + push_event( + store, + transmit_id.rep(), + match ty { + TableIndex::Future(_) => Event::FutureWrite, + TableIndex::Stream(_) => Event::StreamWrite, + }, + count, + caller, + ); + }, + + WriteState::HostReady { accept, close } => { + accept(Reader::Host { + accept: Box::new(move |any| { + _ = tx.send(Some( + *any.downcast() + .map_err(|_| anyhow!("transmit type mismatch"))?, + )); + Ok(()) + }), + })?; + + if close { + store.concurrent_state().table.get_mut(transmit_id)?.write = WriteState::Closed; + } + } + + WriteState::Closed => { + host_close_reader(store, rep)?; + } + } + + Ok(rx) +} + +fn host_cancel_write>(mut store: S, rep: u32) -> Result { + let mut store = store.as_context_mut(); + let transmit_id = TableId::::new(rep); + let transmit = store.concurrent_state().table.get_mut(transmit_id)?; + + match &transmit.write { + WriteState::GuestReady { caller, .. } => { + let caller = *caller; + transmit.write = WriteState::Open; + store + .concurrent_state() + .table + .remove_child(transmit_id, caller)?; + } + + WriteState::HostReady { .. } => { + transmit.write = WriteState::Open; + } + + WriteState::Open | WriteState::Closed => { + bail!("stream or future write canceled when no write is pending") + } + } + + log::trace!("canceled write {rep}"); + + Ok(0) +} + +fn host_cancel_read>(mut store: S, rep: u32) -> Result { + let mut store = store.as_context_mut(); + let transmit_id = TableId::::new(rep); + let transmit = store.concurrent_state().table.get_mut(transmit_id)?; + + match &transmit.read { + ReadState::GuestReady { caller, .. } => { + let caller = *caller; + transmit.read = ReadState::Open; + store + .concurrent_state() + .table + .remove_child(transmit_id, caller)?; + } + + ReadState::HostReady { .. } => { + transmit.read = ReadState::Open; + } + + ReadState::Open | ReadState::Closed => { + bail!("stream or future read canceled when no read is pending") + } + } + + log::trace!("canceled read {rep}"); + + Ok(0) +} + +fn host_close_writer>(mut store: S, rep: u32) -> Result<()> { + let mut store = store.as_context_mut(); + let transmit_id = TableId::::new(rep); + let transmit = store.concurrent_state().table.get_mut(transmit_id)?; + + match &mut transmit.write { + WriteState::GuestReady { close, .. } => { + *close = true; + } + + WriteState::HostReady { close, .. } => { + *close = true; + } + + v @ WriteState::Open => { + *v = WriteState::Closed; + } + + WriteState::Closed => unreachable!(), + } + + let new_state = if let ReadState::Closed = &transmit.read { + ReadState::Closed + } else { + ReadState::Open + }; + + match mem::replace(&mut transmit.read, new_state) { + ReadState::GuestReady { + ty, + instance, + handle, + caller, + .. + } => unsafe { + push_event( + store, + transmit_id.rep(), + match ty { + TableIndex::Future(_) => Event::FutureRead, + TableIndex::Stream(_) => Event::StreamRead, + }, + CLOSED, + caller, + ); + + *get_mut_by_index(&mut *instance.as_ptr(), ty, handle)?.1 = StreamFutureState::Read; + }, + + ReadState::HostReady { accept } => { + accept(Writer::None)?; + + host_close_reader(store, rep)?; + } + + ReadState::Open => {} + + ReadState::Closed => { + log::trace!("host_close_writer delete {}", transmit_id.rep()); + store.concurrent_state().table.delete(transmit_id)?; + } + } + Ok(()) +} + +fn host_close_reader>(mut store: S, rep: u32) -> Result<()> { + let mut store = store.as_context_mut(); + let transmit_id = TableId::::new(rep); + let transmit = store.concurrent_state().table.get_mut(transmit_id)?; + + transmit.read = ReadState::Closed; + + let new_state = if let WriteState::Closed = &transmit.write { + WriteState::Closed + } else { + WriteState::Open + }; + + match mem::replace(&mut transmit.write, new_state) { + WriteState::GuestReady { + ty, + instance, + handle, + close, + caller, + .. + } => unsafe { + push_event( + store.as_context_mut(), + transmit_id.rep(), + match ty { + TableIndex::Future(_) => Event::FutureRead, + TableIndex::Stream(_) => Event::StreamRead, + }, + CLOSED, + caller, + ); + + if close { + store.concurrent_state().table.delete(transmit_id)?; + } else { + *get_mut_by_index(&mut *instance.as_ptr(), ty, handle)?.1 = + StreamFutureState::Write; + } + }, + + WriteState::HostReady { accept, close } => { + accept(Reader::None)?; + + if close { + store.concurrent_state().table.delete(transmit_id)?; + } + } + + WriteState::Open => {} + + WriteState::Closed => { + log::trace!("host_close_reader delete {}", transmit_id.rep()); + store.concurrent_state().table.delete(transmit_id)?; + } + } + Ok(()) +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(super) struct FlatAbi { + pub(super) size: u32, + pub(super) align: u32, +} + +/// Represents the writable end of a Component Model `future`. +pub struct FutureWriter { + rep: u32, + _phantom: PhantomData, +} + +impl FutureWriter { + /// Write the specified value to this `future`. + pub fn write>(self, store: S, value: T) -> Result> + where + T: func::Lower + Send + Sync + 'static, + { + Ok(Promise(Box::pin( + host_write(store, self.rep, vec![value], true)?.map(drop), + ))) + } + + /// Close this object without writing a value. + /// + /// If this object is dropped without calling either this method or `write`, + /// any read on the readable end will remain pending forever. + pub fn close>(self, store: S) -> Result<()> { + host_close_writer(store, self.rep) + } +} /// Represents the readable end of a Component Model `future`. pub struct FutureReader { + rep: u32, _phantom: PhantomData, } +impl FutureReader { + pub(crate) fn new(rep: u32) -> Self { + Self { + rep, + _phantom: PhantomData, + } + } + + /// Read the value from this `future`. + pub fn read>(self, store: S) -> Result>> + where + T: func::Lift + Sync + Send + 'static, + { + Ok(Promise(Box::pin(host_read(store, self.rep)?.map(|v| { + v.ok() + .and_then(|v| v.map(|v| v.into_iter().next().unwrap())) + })))) + } + + /// Convert this `FutureReader` into a [`Val`]. + pub fn into_val(self) -> Val { + Val::Future(FutureAny(self.rep)) + } + + /// Attempt to convert the specified [`Val`] to a `FutureReader`. + pub fn from_val>(mut store: S, value: &Val) -> Result { + let Val::Future(FutureAny(rep)) = value else { + bail!("expected `future`; got `{}`", value.desc()); + }; + store + .as_context_mut() + .concurrent_state() + .table + .get(TableId::::new(*rep))?; + Ok(Self::new(*rep)) + } + + fn lower_to_index(&self, cx: &mut LowerContext<'_, U>, ty: InterfaceType) -> Result { + match ty { + InterfaceType::Future(dst) => { + state_table(unsafe { &mut *cx.instance }, TableIndex::Future(dst)).insert( + self.rep, + WaitableState::Future(dst, StreamFutureState::Read), + ) + } + _ => func::bad_type_info(), + } + } + + fn lift_from_index(cx: &mut LiftContext<'_>, ty: InterfaceType, index: u32) -> Result { + match ty { + InterfaceType::Future(src) => { + let state_table = + state_table(unsafe { &mut *cx.instance }, TableIndex::Future(src)); + let (rep, state) = + get_mut_by_index_from(state_table, TableIndex::Future(src), index)?; + + match state { + StreamFutureState::Local => { + *state = StreamFutureState::Write; + } + StreamFutureState::Read => { + state_table.remove_by_index(index)?; + } + StreamFutureState::Write => bail!("cannot transfer write end of future"), + StreamFutureState::Busy => bail!("cannot transfer busy future"), + } + + Ok(Self { + rep, + _phantom: PhantomData, + }) + } + _ => func::bad_type_info(), + } + } + + /// Close this object without reading the value. + /// + /// If this object is dropped without calling either this method or `read`, + /// any write on the writable end will remain pending forever. + pub fn close>(self, store: S) -> Result<()> { + host_close_reader(store, self.rep) + } +} + +unsafe impl func::ComponentType for FutureReader { + const ABI: CanonicalAbiInfo = CanonicalAbiInfo::SCALAR4; + + type Lower = ::Lower; + + fn typecheck(ty: &InterfaceType, _types: &InstanceType<'_>) -> Result<()> { + match ty { + InterfaceType::Future(_) => Ok(()), + other => bail!("expected `future`, found `{}`", func::desc(other)), + } + } +} + +unsafe impl func::Lower for FutureReader { + fn lower( + &self, + cx: &mut LowerContext<'_, U>, + ty: InterfaceType, + dst: &mut MaybeUninit, + ) -> Result<()> { + self.lower_to_index(cx, ty)? + .lower(cx, InterfaceType::U32, dst) + } + + fn store( + &self, + cx: &mut LowerContext<'_, U>, + ty: InterfaceType, + offset: usize, + ) -> Result<()> { + self.lower_to_index(cx, ty)? + .store(cx, InterfaceType::U32, offset) + } +} + +unsafe impl func::Lift for FutureReader { + fn lift(cx: &mut LiftContext<'_>, ty: InterfaceType, src: &Self::Lower) -> Result { + let index = u32::lift(cx, InterfaceType::U32, src)?; + Self::lift_from_index(cx, ty, index) + } + + fn load(cx: &mut LiftContext<'_>, ty: InterfaceType, bytes: &[u8]) -> Result { + let index = u32::load(cx, InterfaceType::U32, bytes)?; + Self::lift_from_index(cx, ty, index) + } +} + +/// Create a new Component Model `future` as pair of writable and readable ends, +/// the latter of which may be passed to guest code. +pub fn future>( + mut store: S, +) -> Result<(FutureWriter, FutureReader)> { + let mut store = store.as_context_mut(); + let transmit = store.concurrent_state().table.push(TransmitState { + read: ReadState::Open, + write: WriteState::Open, + })?; + + Ok(( + FutureWriter { + rep: transmit.rep(), + _phantom: PhantomData, + }, + FutureReader { + rep: transmit.rep(), + _phantom: PhantomData, + }, + )) +} + +/// Represents the writable end of a Component Model `stream`. +pub struct StreamWriter { + rep: u32, + _phantom: PhantomData, +} + +impl StreamWriter { + /// Write the specified values to the `stream`. + pub fn write>( + self, + store: S, + values: Vec, + ) -> Result>> + where + T: func::Lower + Send + Sync + 'static, + { + Ok(Promise(Box::pin( + host_write(store, self.rep, values, false)?.map(move |_| self), + ))) + } + + /// Close this object without writing any more values. + /// + /// If this object is dropped without calling this method, any read on the + /// readable end will remain pending forever. + pub fn close>(self, store: S) -> Result<()> { + host_close_writer(store, self.rep) + } +} + /// Represents the readable end of a Component Model `stream`. pub struct StreamReader { + rep: u32, _phantom: PhantomData, } +impl StreamReader { + pub(crate) fn new(rep: u32) -> Self { + Self { + rep, + _phantom: PhantomData, + } + } + + /// Read the next values (if any) from this `stream`. + pub fn read>( + self, + store: S, + ) -> Result, Vec)>>> + where + T: func::Lift + Sync + Send + 'static, + { + Ok(Promise(Box::pin( + host_read(store, self.rep)?.map(move |v| v.ok().and_then(|v| v.map(|v| (self, v)))), + ))) + } + + /// Convert this `StreamReader` into a [`Val`]. + pub fn into_val(self) -> Val { + Val::Stream(StreamAny(self.rep)) + } + + /// Attempt to convert the specified [`Val`] to a `StreamReader`. + pub fn from_val>(mut store: S, value: &Val) -> Result { + let Val::Stream(StreamAny(rep)) = value else { + bail!("expected `stream`; got `{}`", value.desc()); + }; + store + .as_context_mut() + .concurrent_state() + .table + .get(TableId::::new(*rep))?; + Ok(Self::new(*rep)) + } + + fn lower_to_index(&self, cx: &mut LowerContext<'_, U>, ty: InterfaceType) -> Result { + match ty { + InterfaceType::Stream(dst) => { + state_table(unsafe { &mut *cx.instance }, TableIndex::Stream(dst)).insert( + self.rep, + WaitableState::Stream(dst, StreamFutureState::Read), + ) + } + _ => func::bad_type_info(), + } + } + + fn lift_from_index(cx: &mut LiftContext<'_>, ty: InterfaceType, index: u32) -> Result { + match ty { + InterfaceType::Stream(src) => { + let state_table = + state_table(unsafe { &mut *cx.instance }, TableIndex::Stream(src)); + let (rep, state) = + get_mut_by_index_from(state_table, TableIndex::Stream(src), index)?; + + match state { + StreamFutureState::Local => { + *state = StreamFutureState::Write; + } + StreamFutureState::Read => { + state_table.remove_by_index(index)?; + } + StreamFutureState::Write => bail!("cannot transfer write end of stream"), + StreamFutureState::Busy => bail!("cannot transfer busy stream"), + } + + Ok(Self { + rep, + _phantom: PhantomData, + }) + } + _ => func::bad_type_info(), + } + } + + /// Close this object without reading any more values. + /// + /// If the object is dropped without either calling this method or reading + /// until the end of the stream, any write on the writable end will remain + /// pending forever. + pub fn close>(self, store: S) -> Result<()> { + host_close_reader(store, self.rep) + } +} + +unsafe impl func::ComponentType for StreamReader { + const ABI: CanonicalAbiInfo = CanonicalAbiInfo::SCALAR4; + + type Lower = ::Lower; + + fn typecheck(ty: &InterfaceType, _types: &InstanceType<'_>) -> Result<()> { + match ty { + InterfaceType::Stream(_) => Ok(()), + other => bail!("expected `stream`, found `{}`", func::desc(other)), + } + } +} + +unsafe impl func::Lower for StreamReader { + fn lower( + &self, + cx: &mut LowerContext<'_, U>, + ty: InterfaceType, + dst: &mut MaybeUninit, + ) -> Result<()> { + self.lower_to_index(cx, ty)? + .lower(cx, InterfaceType::U32, dst) + } + + fn store( + &self, + cx: &mut LowerContext<'_, U>, + ty: InterfaceType, + offset: usize, + ) -> Result<()> { + self.lower_to_index(cx, ty)? + .store(cx, InterfaceType::U32, offset) + } +} + +unsafe impl func::Lift for StreamReader { + fn lift(cx: &mut LiftContext<'_>, ty: InterfaceType, src: &Self::Lower) -> Result { + let index = u32::lift(cx, InterfaceType::U32, src)?; + Self::lift_from_index(cx, ty, index) + } + + fn load(cx: &mut LiftContext<'_>, ty: InterfaceType, bytes: &[u8]) -> Result { + let index = u32::load(cx, InterfaceType::U32, bytes)?; + Self::lift_from_index(cx, ty, index) + } +} + +/// Create a new Component Model `stream` as pair of writable and readable ends, +/// the latter of which may be passed to guest code. +pub fn stream>( + mut store: S, +) -> Result<(StreamWriter, StreamReader)> { + let mut store = store.as_context_mut(); + let transmit = store.concurrent_state().table.push(TransmitState { + read: ReadState::Open, + write: WriteState::Open, + })?; + + Ok(( + StreamWriter { + rep: transmit.rep(), + _phantom: PhantomData, + }, + StreamReader { + rep: transmit.rep(), + _phantom: PhantomData, + }, + )) +} + /// Represents a Component Model `error-context`. -pub struct ErrorContext {} +pub struct ErrorContext { + rep: u32, +} + +impl ErrorContext { + pub(crate) fn new(rep: u32) -> Self { + Self { rep } + } + + /// Convert this `ErrorContext` into a [`Val`]. + pub fn into_val(self) -> Val { + Val::ErrorContext(ErrorContextAny(self.rep)) + } + + /// Attempt to convert the specified [`Val`] to a `ErrorContext`. + pub fn from_val>(_: S, value: &Val) -> Result { + let Val::ErrorContext(ErrorContextAny(rep)) = value else { + bail!("expected `error-context`; got `{}`", value.desc()); + }; + Ok(Self::new(*rep)) + } + + fn lower_to_index(&self, cx: &mut LowerContext<'_, U>, ty: InterfaceType) -> Result { + match ty { + InterfaceType::ErrorContext(dst) => { + let tbl = unsafe { + &mut (*cx.instance) + .component_error_context_tables() + .get_mut(dst) + .expect("error context table index present in (sub)component table during lower") + }; + + if let Some((dst_idx, dst_state)) = tbl.get_mut_by_rep(self.rep) { + dst_state.0 += 1; + Ok(dst_idx) + } else { + tbl.insert(self.rep, LocalErrorContextRefCount(1)) + } + } + _ => func::bad_type_info(), + } + } + + fn lift_from_index(cx: &mut LiftContext<'_>, ty: InterfaceType, index: u32) -> Result { + match ty { + InterfaceType::ErrorContext(src) => { + let (rep, _) = unsafe { + (*cx.instance) + .component_error_context_tables() + .get_mut(src) + .expect( + "error context table index present in (sub)component table during lift", + ) + .get_mut_by_index(index)? + }; + + Ok(Self { rep }) + } + _ => func::bad_type_info(), + } + } +} + +unsafe impl func::ComponentType for ErrorContext { + const ABI: CanonicalAbiInfo = CanonicalAbiInfo::SCALAR4; + + type Lower = ::Lower; + + fn typecheck(ty: &InterfaceType, _types: &InstanceType<'_>) -> Result<()> { + match ty { + InterfaceType::ErrorContext(_) => Ok(()), + other => bail!("expected `error`, found `{}`", func::desc(other)), + } + } +} + +unsafe impl func::Lower for ErrorContext { + fn lower( + &self, + cx: &mut LowerContext<'_, T>, + ty: InterfaceType, + dst: &mut MaybeUninit, + ) -> Result<()> { + self.lower_to_index(cx, ty)? + .lower(cx, InterfaceType::U32, dst) + } + + fn store( + &self, + cx: &mut LowerContext<'_, T>, + ty: InterfaceType, + offset: usize, + ) -> Result<()> { + self.lower_to_index(cx, ty)? + .store(cx, InterfaceType::U32, offset) + } +} + +unsafe impl func::Lift for ErrorContext { + fn lift(cx: &mut LiftContext<'_>, ty: InterfaceType, src: &Self::Lower) -> Result { + let index = u32::lift(cx, InterfaceType::U32, src)?; + Self::lift_from_index(cx, ty, index) + } + + fn load(cx: &mut LiftContext<'_>, ty: InterfaceType, bytes: &[u8]) -> Result { + let index = u32::load(cx, InterfaceType::U32, bytes)?; + Self::lift_from_index(cx, ty, index) + } +} + +pub(super) struct TransmitState { + write: WriteState, + read: ReadState, +} + +enum WriteState { + Open, + GuestReady { + ty: TableIndex, + flat_abi: Option, + options: Options, + address: usize, + count: usize, + instance: SendSyncPtr, + handle: u32, + caller: TableId, + close: bool, + }, + HostReady { + accept: Box Result + Send + Sync>, + close: bool, + }, + Closed, +} + +enum ReadState { + Open, + GuestReady { + ty: TableIndex, + flat_abi: Option, + options: Options, + address: usize, + count: usize, + instance: SendSyncPtr, + handle: u32, + caller: TableId, + }, + HostReady { + accept: Box Result + Send + Sync>, + }, + Closed, +} + +enum Writer<'a> { + Guest { + lift: &'a mut LiftContext<'a>, + ty: Option, + address: usize, + count: usize, + }, + Host { + values: Box, + }, + None, +} + +struct RawLowerContext<'a> { + store: *mut dyn VMStore, + options: &'a Options, + types: &'a Arc, + instance: *mut ComponentInstance, +} + +enum Reader<'a> { + Guest { + lower: RawLowerContext<'a>, + ty: TableIndex, + address: usize, + count: usize, + }, + Host { + accept: Box) -> Result<()>>, + }, + None, +} + +pub(super) fn guest_new( + mut cx: StoreContextMut, + instance: &mut ComponentInstance, + ty: TableIndex, +) -> Result { + let transmit = cx.concurrent_state().table.push(TransmitState { + read: ReadState::Open, + write: WriteState::Open, + })?; + state_table(instance, ty).insert(transmit.rep(), waitable_state(ty, StreamFutureState::Local)) +} + +fn copy( + mut cx: StoreContextMut<'_, T>, + types: &Arc, + instance: *mut ComponentInstance, + flat_abi: Option, + write_ty: TableIndex, + write_options: &Options, + write_address: usize, + read_ty: TableIndex, + read_options: &Options, + read_address: usize, + count: usize, + rep: u32, +) -> Result<()> { + match (write_ty, read_ty) { + (TableIndex::Future(write_ty), TableIndex::Future(read_ty)) => { + assert_eq!(count, 1); + + let val = types[types[write_ty].ty] + .payload + .map(|ty| { + let abi = types.canonical_abi(&ty); + // FIXME: needs to read an i64 for memory64 + if write_address % usize::try_from(abi.align32)? != 0 { + bail!("write pointer not aligned"); + } + + let lift = + &mut unsafe { LiftContext::new(cx.0, write_options, types, instance) }; + + let bytes = lift + .memory() + .get(write_address..) + .and_then(|b| b.get(..usize::try_from(abi.size32).unwrap())) + .ok_or_else(|| anyhow::anyhow!("write pointer out of bounds of memory"))?; + + Val::load(lift, ty, bytes) + }) + .transpose()?; + + if let Some(val) = val { + let mut lower = unsafe { + LowerContext::new(cx.as_context_mut(), read_options, types, instance) + }; + let ty = types[types[read_ty].ty].payload.unwrap(); + let ptr = func::validate_inbounds_dynamic( + types.canonical_abi(&ty), + lower.as_slice_mut(), + &ValRaw::u32(read_address.try_into().unwrap()), + )?; + val.store(&mut lower, ty, ptr)?; + } + } + (TableIndex::Stream(write_ty), TableIndex::Stream(read_ty)) => { + let lift = &mut unsafe { LiftContext::new(cx.0, write_options, types, instance) }; + if let Some(flat_abi) = flat_abi { + // Fast path memcpy for "flat" (i.e. no pointers or handles) payloads: + let length_in_bytes = usize::try_from(flat_abi.size).unwrap() * count; + if length_in_bytes > 0 { + if write_address % usize::try_from(flat_abi.align)? != 0 { + bail!("write pointer not aligned"); + } + if read_address % usize::try_from(flat_abi.align)? != 0 { + bail!("read pointer not aligned"); + } + + { + let src = write_options + .memory(cx.0) + .get(write_address..) + .and_then(|b| b.get(..length_in_bytes)) + .ok_or_else(|| { + anyhow::anyhow!("write pointer out of bounds of memory") + })? + .as_ptr(); + let dst = read_options + .memory_mut(cx.0) + .get_mut(read_address..) + .and_then(|b| b.get_mut(..length_in_bytes)) + .ok_or_else(|| anyhow::anyhow!("read pointer out of bounds of memory"))? + .as_mut_ptr(); + unsafe { src.copy_to(dst, length_in_bytes) }; + } + } + } else { + let ty = types[types[write_ty].ty].payload.unwrap(); + let abi = lift.types.canonical_abi(&ty); + let size = usize::try_from(abi.size32).unwrap(); + if write_address % usize::try_from(abi.align32)? != 0 { + bail!("write pointer not aligned"); + } + let bytes = lift + .memory() + .get(write_address..) + .and_then(|b| b.get(..size * count)) + .ok_or_else(|| anyhow::anyhow!("write pointer out of bounds of memory"))?; + + let values = (0..count) + .map(|index| Val::load(lift, ty, &bytes[(index * size)..][..size])) + .collect::>>()?; + + log::trace!("copy values {values:?} for {rep}"); + + let lower = &mut unsafe { + LowerContext::new(cx.as_context_mut(), read_options, types, instance) + }; + let ty = types[types[read_ty].ty].payload.unwrap(); + let abi = lower.types.canonical_abi(&ty); + if read_address % usize::try_from(abi.align32)? != 0 { + bail!("read pointer not aligned"); + } + let size = usize::try_from(abi.size32).unwrap(); + lower + .as_slice_mut() + .get_mut(read_address..) + .and_then(|b| b.get_mut(..size * count)) + .ok_or_else(|| anyhow::anyhow!("read pointer out of bounds of memory"))?; + let mut ptr = read_address; + for value in values { + value.store(lower, ty, ptr)?; + ptr += size + } + } + } + _ => unreachable!(), + } + + Ok(()) +} + +pub(super) fn guest_write( + mut cx: StoreContextMut, + instance: *mut ComponentInstance, + memory: *mut VMMemoryDefinition, + realloc: *mut VMFuncRef, + string_encoding: u8, + ty: TableIndex, + flat_abi: Option, + handle: u32, + address: u32, + count: u32, +) -> Result { + let address = usize::try_from(address).unwrap(); + let count = usize::try_from(count).unwrap(); + let options = unsafe { + Options::new( + cx.0.id(), + NonNull::new(memory), + NonNull::new(realloc), + StringEncoding::from_u8(string_encoding).unwrap(), + true, + None, + ) + }; + let types = unsafe { (*instance).component_types() }; + let (rep, state) = unsafe { get_mut_by_index(&mut *instance, ty, handle)? }; + let StreamFutureState::Write = *state else { + bail!("invalid handle"); + }; + *state = StreamFutureState::Busy; + let transmit_id = TableId::::new(rep); + let transmit = cx.concurrent_state().table.get_mut(transmit_id)?; + let new_state = if let ReadState::Closed = &transmit.read { + ReadState::Closed + } else { + ReadState::Open + }; + + let result = match mem::replace(&mut transmit.read, new_state) { + ReadState::GuestReady { + ty: read_ty, + flat_abi: read_flat_abi, + options: read_options, + address: read_address, + count: read_count, + instance: _, + handle: read_handle, + caller: read_caller, + } => { + assert_eq!(flat_abi, read_flat_abi); + + let count = count.min(read_count); + + copy( + cx.as_context_mut(), + types, + instance, + flat_abi, + ty, + &options, + address, + read_ty, + &read_options, + read_address, + count, + rep, + )?; + + log::trace!( + "remove read child of {}: {}", + read_caller.rep(), + transmit_id.rep() + ); + cx.concurrent_state() + .table + .remove_child(transmit_id, read_caller)?; + + unsafe { + *get_mut_by_index(&mut *instance, read_ty, read_handle)?.1 = + StreamFutureState::Read; + } + + push_event( + cx, + transmit_id.rep(), + match read_ty { + TableIndex::Future(_) => Event::FutureRead, + TableIndex::Stream(_) => Event::StreamRead, + }, + count, + read_caller, + ); + + count + } + + ReadState::HostReady { accept } => { + let lift = &mut unsafe { LiftContext::new(cx.0, &options, types, instance) }; + accept(Writer::Guest { + lift, + ty: payload(ty, types), + address, + count, + })? + } + + ReadState::Open => { + assert!(matches!(&transmit.write, WriteState::Open)); + + let caller = cx.concurrent_state().guest_task.unwrap(); + log::trace!( + "add write {} child of {}: {}", + match ty { + TableIndex::Future(_) => "future", + TableIndex::Stream(_) => "stream", + }, + caller.rep(), + transmit_id.rep() + ); + cx.concurrent_state().table.add_child(transmit_id, caller)?; + + let transmit = cx.concurrent_state().table.get_mut(transmit_id)?; + transmit.write = WriteState::GuestReady { + ty, + flat_abi, + options, + address: usize::try_from(address).unwrap(), + count: usize::try_from(count).unwrap(), + instance: SendSyncPtr::new(NonNull::new(instance).unwrap()), + handle, + caller, + close: false, + }; + + BLOCKED + } + + ReadState::Closed => CLOSED, + }; + + if result != BLOCKED { + unsafe { + *get_mut_by_index(&mut *instance, ty, handle)?.1 = StreamFutureState::Write; + } + } + + Ok(u32::try_from(result).unwrap()) +} + +pub(super) fn guest_read( + mut cx: StoreContextMut, + instance: *mut ComponentInstance, + memory: *mut VMMemoryDefinition, + realloc: *mut VMFuncRef, + string_encoding: u8, + ty: TableIndex, + flat_abi: Option, + handle: u32, + address: u32, + count: u32, +) -> Result { + let address = usize::try_from(address).unwrap(); + let count = usize::try_from(count).unwrap(); + let options = unsafe { + Options::new( + cx.0.id(), + NonNull::new(memory), + NonNull::new(realloc), + StringEncoding::from_u8(string_encoding).unwrap(), + true, + None, + ) + }; + let types = unsafe { (*instance).component_types() }; + let (rep, state) = unsafe { get_mut_by_index(&mut *instance, ty, handle)? }; + let StreamFutureState::Read = *state else { + bail!("invalid handle"); + }; + *state = StreamFutureState::Busy; + let transmit_id = TableId::::new(rep); + let transmit = cx.concurrent_state().table.get_mut(transmit_id)?; + let new_state = if let WriteState::Closed = &transmit.write { + WriteState::Closed + } else { + WriteState::Open + }; + + let result = match mem::replace(&mut transmit.write, new_state) { + WriteState::GuestReady { + ty: write_ty, + flat_abi: write_flat_abi, + options: write_options, + address: write_address, + count: write_count, + instance: _, + handle: write_handle, + caller: write_caller, + close, + } => { + assert_eq!(flat_abi, write_flat_abi); + + let count = count.min(write_count); + + copy( + cx.as_context_mut(), + types, + instance, + flat_abi, + write_ty, + &write_options, + write_address, + ty, + &options, + address, + count, + rep, + )?; + + log::trace!( + "remove write child of {}: {}", + write_caller.rep(), + transmit_id.rep() + ); + cx.concurrent_state() + .table + .remove_child(transmit_id, write_caller)?; + + if close { + cx.concurrent_state().table.get_mut(transmit_id)?.write = WriteState::Closed; + } else { + unsafe { + *get_mut_by_index(&mut *instance, write_ty, write_handle)?.1 = + StreamFutureState::Write; + } + } + + push_event( + cx, + transmit_id.rep(), + match write_ty { + TableIndex::Future(_) => Event::FutureWrite, + TableIndex::Stream(_) => Event::StreamWrite, + }, + count, + write_caller, + ); + + count + } + + WriteState::HostReady { accept, close } => { + let count = accept(Reader::Guest { + lower: RawLowerContext { + store: cx.0.traitobj().as_ptr(), + options: &options, + types, + instance, + }, + ty, + address: usize::try_from(address).unwrap(), + count, + })?; + + if close { + cx.concurrent_state().table.get_mut(transmit_id)?.write = WriteState::Closed; + } + + count + } + + WriteState::Open => { + assert!(matches!(&transmit.read, ReadState::Open)); + + let caller = cx.concurrent_state().guest_task.unwrap(); + log::trace!( + "add read {} child of {}: {}", + match ty { + TableIndex::Future(_) => "future", + TableIndex::Stream(_) => "stream", + }, + caller.rep(), + transmit_id.rep() + ); + cx.concurrent_state().table.add_child(transmit_id, caller)?; + + let transmit = cx.concurrent_state().table.get_mut(transmit_id)?; + transmit.read = ReadState::GuestReady { + ty, + flat_abi, + options, + address: usize::try_from(address).unwrap(), + count: usize::try_from(count).unwrap(), + instance: SendSyncPtr::new(NonNull::new(instance).unwrap()), + handle, + caller, + }; + + BLOCKED + } + + WriteState::Closed => CLOSED, + }; + + if result != BLOCKED { + unsafe { + *get_mut_by_index(&mut *instance, ty, handle)?.1 = StreamFutureState::Read; + } + } + + Ok(u32::try_from(result).unwrap()) +} + +pub(super) fn guest_cancel_write( + cx: StoreContextMut, + instance: &mut ComponentInstance, + ty: TableIndex, + writer: u32, + _async_: bool, +) -> Result { + let (rep, WaitableState::Stream(_, state) | WaitableState::Future(_, state)) = + state_table(instance, ty).get_mut_by_index(writer)? + else { + bail!("invalid stream or future handle"); + }; + match state { + StreamFutureState::Local | StreamFutureState::Write => { + bail!("stream or future write canceled when no write is pending") + } + StreamFutureState::Read => { + bail!("passed read end to `{{stream|future}}.cancel-write`") + } + StreamFutureState::Busy => { + *state = StreamFutureState::Write; + } + } + host_cancel_write(cx, rep) +} + +pub(super) fn guest_cancel_read( + cx: StoreContextMut, + instance: &mut ComponentInstance, + ty: TableIndex, + reader: u32, + _async_: bool, +) -> Result { + let (rep, WaitableState::Stream(_, state) | WaitableState::Future(_, state)) = + state_table(instance, ty).get_mut_by_index(reader)? + else { + bail!("invalid stream or future handle"); + }; + match state { + StreamFutureState::Local | StreamFutureState::Read => { + bail!("stream or future read canceled when no read is pending") + } + StreamFutureState::Write => { + bail!("passed write end to `{{stream|future}}.cancel-read`") + } + StreamFutureState::Busy => { + *state = StreamFutureState::Read; + } + } + host_cancel_read(cx, rep) +} + +pub(super) fn guest_close_writable( + cx: StoreContextMut, + instance: &mut ComponentInstance, + ty: TableIndex, + writer: u32, + error: u32, +) -> Result<()> { + if error != 0 { + bail!("todo: closing writable streams and futures with errors not yet implemented"); + } + + let (rep, WaitableState::Stream(_, state) | WaitableState::Future(_, state)) = + state_table(instance, ty).remove_by_index(writer)? + else { + bail!("invalid stream or future handle"); + }; + match state { + StreamFutureState::Local | StreamFutureState::Write => {} + StreamFutureState::Read => { + bail!("passed read end to `{{stream|future}}.close-writable`") + } + StreamFutureState::Busy => bail!("cannot drop busy stream or future"), + } + host_close_writer(cx, rep) +} + +pub(super) fn guest_close_readable( + cx: StoreContextMut, + instance: &mut ComponentInstance, + ty: TableIndex, + reader: u32, +) -> Result<()> { + let (rep, WaitableState::Stream(_, state) | WaitableState::Future(_, state)) = + state_table(instance, ty).remove_by_index(reader)? + else { + bail!("invalid stream or future handle"); + }; + match state { + StreamFutureState::Local | StreamFutureState::Read => {} + StreamFutureState::Write => { + bail!("passed write end to `{{stream|future}}.close-readable`") + } + StreamFutureState::Busy => bail!("cannot drop busy stream or future"), + } + host_close_reader(cx, rep) +} + +/// Create a new error context for the given component +pub(super) fn error_context_new( + mut cx: StoreContextMut, + instance: *mut ComponentInstance, + memory: *mut VMMemoryDefinition, + realloc: *mut VMFuncRef, + string_encoding: u8, + ty: TypeComponentLocalErrorContextTableIndex, + debug_msg_address: u32, + debug_msg_len: u32, +) -> Result { + // Read string from guest memory + let options = unsafe { + Options::new( + cx.0.id(), + NonNull::new(memory), + NonNull::new(realloc), + StringEncoding::from_u8(string_encoding).ok_or_else(|| { + anyhow::anyhow!("failed to convert u8 string encoding [{string_encoding}]") + })?, + false, + None, + ) + }; + let lift_ctx = + &mut unsafe { LiftContext::new(cx.0, &options, (*instance).component_types(), instance) }; + let s = { + let address = usize::try_from(debug_msg_address)?; + let len = usize::try_from(debug_msg_len)?; + WasmStr::load( + lift_ctx, + InterfaceType::String, + &lift_ctx + .memory() + .get(address..) + .and_then(|b| b.get(..len)) + .map(|_| [debug_msg_address.to_le_bytes(), debug_msg_len.to_le_bytes()].concat()) + .ok_or_else(|| anyhow::anyhow!("invalid debug message pointer: out of bounds"))?, + )? + }; + + // Create a new ErrorContext that is tracked along with other concurrent state + let err_ctx = ErrorContextState { + debug_msg: s.to_str(&cx)?.to_string(), + }; + let table_id = cx.concurrent_state().table.push(err_ctx)?; + let global_ref_count_idx = TypeComponentGlobalErrorContextTableIndex::from_u32(table_id.rep()); + + // Add to the global error context ref counts + unsafe { + let _ = (*instance) + .component_global_error_context_ref_counts() + .insert(global_ref_count_idx, GlobalErrorContextRefCount(1)); + } + + // Error context are tracked both locally (to a single component instance) and globally + // the counts for both must stay in sync. + // + // Here we reflect the newly created global concurrent error context state into the + // component instance's locally tracked count, along with the appropriate key into the global + // ref tracking data structures to enable later lookup + let local_tbl = unsafe { + (*instance) + .component_error_context_tables() + .get_mut_or_insert_with(ty, || StateTable::default()) + }; + assert!( + !local_tbl.has_handle(table_id.rep()), + "newly created error context state already tracked by component" + ); + let local_idx = local_tbl.insert(table_id.rep(), LocalErrorContextRefCount(1))?; + + Ok(local_idx) +} + +pub(super) fn error_context_debug_message( + mut cx: StoreContextMut, + instance: *mut ComponentInstance, + memory: *mut VMMemoryDefinition, + realloc: *mut VMFuncRef, + string_encoding: u8, + ty: TypeComponentLocalErrorContextTableIndex, + err_ctx_handle: u32, + debug_msg_address: u32, +) -> Result<()> { + let store_id = cx.0.id(); + + // Retrieve the error context and internal debug message + let (state_table_id_rep, _) = unsafe { + (*instance) + .component_error_context_tables() + .get_mut(ty) + .context("error context table index present in (sub)component lookup during debug_msg")? + .get_mut_by_index(err_ctx_handle)? + }; + + // Get the state associated with the error context + let ErrorContextState { debug_msg } = + cx.concurrent_state() + .table + .get_mut(TableId::::new(state_table_id_rep))?; + let debug_msg = debug_msg.clone(); + + // Lower the string into the component's memory + let options = unsafe { + Options::new( + store_id, + NonNull::new(memory), + NonNull::new(realloc), + StringEncoding::from_u8(string_encoding).ok_or_else(|| { + anyhow::anyhow!("failed to convert u8 string encoding [{string_encoding}]") + })?, + false, + None, + ) + }; + let lower_cx = + &mut unsafe { LowerContext::new(cx, &options, (*instance).component_types(), instance) }; + let debug_msg_address = usize::try_from(debug_msg_address)?; + let offset = lower_cx + .as_slice_mut() + .get(debug_msg_address..) + .and_then(|b| b.get(..debug_msg.bytes().len())) + .map(|_| debug_msg_address) + .ok_or_else(|| anyhow::anyhow!("invalid debug message pointer: out of bounds"))?; + debug_msg + .as_str() + .store(lower_cx, InterfaceType::String, offset)?; + + Ok(()) +} + +pub(super) fn error_context_drop( + mut cx: StoreContextMut, + instance: &mut ComponentInstance, + ty: TypeComponentLocalErrorContextTableIndex, + error_context: u32, +) -> Result<()> { + let local_state_table = instance + .component_error_context_tables() + .get_mut(ty) + .context("error context table index present in (sub)component table during drop")?; + + // Reduce the local (sub)component ref count, removing tracking if necessary + let (rep, local_ref_removed) = { + let (rep, LocalErrorContextRefCount(local_ref_count)) = + local_state_table.get_mut_by_index(error_context)?; + assert!(*local_ref_count > 0); + *local_ref_count -= 1; + let mut local_ref_removed = false; + if *local_ref_count == 0 { + local_ref_removed = true; + local_state_table + .remove_by_index(error_context) + .context("removing error context from component-local tracking")?; + } + (rep, local_ref_removed) + }; + + let global_ref_count_idx = TypeComponentGlobalErrorContextTableIndex::from_u32(rep); + + let GlobalErrorContextRefCount(global_ref_count) = instance + .component_global_error_context_ref_counts() + .get_mut(&global_ref_count_idx) + .expect("retrieve concurrent state for error context during drop"); + + // Reduce the component-global ref count, removing tracking if necessary + assert!(*global_ref_count >= 1); + *global_ref_count -= 1; + if *global_ref_count == 0 { + assert!(local_ref_removed); + + instance + .component_global_error_context_ref_counts() + .remove(&global_ref_count_idx); + + cx.concurrent_state() + .table + .delete(TableId::::new(rep)) + .context("deleting component-global error context data")?; + } + + Ok(()) +} diff --git a/crates/wasmtime/src/runtime/component/concurrent/ready_chunks.rs b/crates/wasmtime/src/runtime/component/concurrent/ready_chunks.rs new file mode 100644 index 000000000000..f82bddcee4c7 --- /dev/null +++ b/crates/wasmtime/src/runtime/component/concurrent/ready_chunks.rs @@ -0,0 +1,59 @@ +//! Like `futures::stream::ReadyChunks` but without fusing the inner stream. +//! +//! We use this with `FuturesUnordered` which may produce `Poll::Ready(None)` but later produce more elements due +//! to additional futures having been added, so fusing is not appropriate. + +use { + futures::{Stream, StreamExt}, + std::{ + pin::Pin, + task::{Context, Poll}, + vec::Vec, + }, +}; + +pub struct ReadyChunks { + stream: S, + capacity: usize, +} + +impl ReadyChunks { + pub fn new(stream: S, capacity: usize) -> Self { + Self { stream, capacity } + } + + pub fn get_mut(&mut self) -> &mut S { + &mut self.stream + } +} + +impl Stream for ReadyChunks { + type Item = Vec; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut items = Vec::new(); + + loop { + match self.stream.poll_next_unpin(cx) { + Poll::Pending => { + break if items.is_empty() { + Poll::Pending + } else { + Poll::Ready(Some(items)) + } + } + + Poll::Ready(Some(item)) => { + items.push(item); + if items.len() >= self.capacity { + break Poll::Ready(Some(items)); + } + } + + Poll::Ready(None) => { + break Poll::Ready(if items.is_empty() { None } else { Some(items) }); + } + } + } + } +} diff --git a/crates/wasmtime/src/runtime/component/concurrent/table.rs b/crates/wasmtime/src/runtime/component/concurrent/table.rs new file mode 100644 index 000000000000..a609052244bf --- /dev/null +++ b/crates/wasmtime/src/runtime/component/concurrent/table.rs @@ -0,0 +1,316 @@ +// TODO: This duplicates a lot of resource_table.rs; consider reducing that +// duplication. +// +// The main difference between this and resource_table.rs is that the key type, +// `TableId` implements `Copy`, making them much easier to work with than +// `Resource`. I've also added a `Table::delete_any` function, useful for +// implementing `subtask.drop`. + +use std::{any::Any, boxed::Box, collections::BTreeSet, marker::PhantomData, vec::Vec}; + +pub struct TableId { + rep: u32, + _marker: PhantomData T>, +} + +impl TableId { + pub fn new(rep: u32) -> Self { + Self { + rep, + _marker: PhantomData, + } + } +} + +impl Clone for TableId { + fn clone(&self) -> Self { + Self::new(self.rep) + } +} + +impl Copy for TableId {} + +impl TableId { + pub fn rep(&self) -> u32 { + self.rep + } +} + +#[derive(Debug)] +/// Errors returned by operations on `Table` +pub enum TableError { + /// Table has no free keys + Full, + /// Entry not present in table + NotPresent, + /// Resource present in table, but with a different type + WrongType, + /// Entry cannot be deleted because child entrys exist in the table. + HasChildren, +} + +impl std::fmt::Display for TableError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Full => write!(f, "table has no free keys"), + Self::NotPresent => write!(f, "entry not present"), + Self::WrongType => write!(f, "entry is of another type"), + Self::HasChildren => write!(f, "entry has children"), + } + } +} +impl std::error::Error for TableError {} + +/// The `Table` type maps a `TableId` to its entry. +#[derive(Default)] +pub struct Table { + entries: Vec, + free_head: Option, +} + +enum Entry { + Free { next: Option }, + Occupied { entry: TableEntry }, +} + +impl Entry { + pub fn occupied(&self) -> Option<&TableEntry> { + match self { + Self::Occupied { entry } => Some(entry), + Self::Free { .. } => None, + } + } + + pub fn occupied_mut(&mut self) -> Option<&mut TableEntry> { + match self { + Self::Occupied { entry } => Some(entry), + Self::Free { .. } => None, + } + } +} + +/// This structure tracks parent and child relationships for a given table entry. +/// +/// Parents and children are referred to by table index. We maintain the +/// following invariants to prevent orphans and cycles: +/// * parent can only be assigned on creating the entry. +/// * parent, if some, must exist when creating the entry. +/// * whenever a child is created, its index is added to children. +/// * whenever a child is deleted, its index is removed from children. +/// * an entry with children may not be deleted. +struct TableEntry { + /// The entry in the table + entry: Box, + /// The index of the parent of this entry, if it has one. + parent: Option, + /// The indicies of any children of this entry. + children: BTreeSet, +} + +impl TableEntry { + fn new(entry: Box, parent: Option) -> Self { + Self { + entry, + parent, + children: BTreeSet::new(), + } + } + fn add_child(&mut self, child: u32) { + assert!(self.children.insert(child)); + } + fn remove_child(&mut self, child: u32) { + assert!(self.children.remove(&child)); + } +} + +impl Table { + /// Create an empty table + pub fn new() -> Self { + let mut me = Self { + entries: Vec::new(), + free_head: None, + }; + + // TODO: remove this once we've stopped exposing these indexes to guest code: + me.push(()).unwrap(); + + me + } + + /// Inserts a new entry into this table, returning a corresponding + /// `TableId` which can be used to refer to it after it was inserted. + pub fn push(&mut self, entry: T) -> Result, TableError> { + let idx = self.push_(TableEntry::new(Box::new(entry), None))?; + Ok(TableId::new(idx)) + } + + /// Pop an index off of the free list, if it's not empty. + fn pop_free_list(&mut self) -> Option { + if let Some(ix) = self.free_head { + // Advance free_head to the next entry if one is available. + match &self.entries[ix] { + Entry::Free { next } => self.free_head = *next, + Entry::Occupied { .. } => unreachable!(), + } + Some(ix) + } else { + None + } + } + + /// Free an entry in the table, returning its [`TableEntry`]. Add the index to the free list. + fn free_entry(&mut self, ix: usize) -> TableEntry { + let entry = match std::mem::replace( + &mut self.entries[ix], + Entry::Free { + next: self.free_head, + }, + ) { + Entry::Occupied { entry } => entry, + Entry::Free { .. } => unreachable!(), + }; + + self.free_head = Some(ix); + + entry + } + + /// Push a new entry into the table, returning its handle. This will prefer to use free entries + /// if they exist, falling back on pushing new entries onto the end of the table. + fn push_(&mut self, e: TableEntry) -> Result { + if let Some(free) = self.pop_free_list() { + self.entries[free] = Entry::Occupied { entry: e }; + Ok(u32::try_from(free).unwrap()) + } else { + let ix = self + .entries + .len() + .try_into() + .map_err(|_| TableError::Full)?; + self.entries.push(Entry::Occupied { entry: e }); + Ok(ix) + } + } + + fn occupied(&self, key: u32) -> Result<&TableEntry, TableError> { + self.entries + .get(key as usize) + .and_then(Entry::occupied) + .ok_or(TableError::NotPresent) + } + + fn occupied_mut(&mut self, key: u32) -> Result<&mut TableEntry, TableError> { + self.entries + .get_mut(key as usize) + .and_then(Entry::occupied_mut) + .ok_or(TableError::NotPresent) + } + + /// Insert a entry at the next available index, and track that it has a + /// parent entry. + /// + /// The parent must exist to create a child. All child entrys must be + /// destroyed before a parent can be destroyed - otherwise [`Table::delete`] + /// will fail with [`TableError::HasChildren`]. + /// + /// Parent-child relationships are tracked inside the table to ensure that a + /// parent is not deleted while it has live children. This allows children + /// to hold "references" to a parent by table index, to avoid needing + /// e.g. an `Arc>` and the associated locking overhead and + /// design issues, such as child existence extending lifetime of parent + /// referent even after parent is destroyed, possibility for deadlocks. + /// + /// Parent-child relationships may not be modified once created. There is no + /// way to observe these relationships through the [`Table`] methods except + /// for erroring on deletion, or the [`std::fmt::Debug`] impl. + pub fn push_child( + &mut self, + entry: T, + parent: TableId, + ) -> Result, TableError> { + let parent = parent.rep(); + self.occupied(parent)?; + let child = self.push_(TableEntry::new(Box::new(entry), Some(parent)))?; + self.occupied_mut(parent)?.add_child(child); + Ok(TableId::new(child)) + } + + pub fn add_child( + &mut self, + child: TableId, + parent: TableId, + ) -> Result<(), TableError> { + let entry = self.occupied_mut(child.rep())?; + assert!(entry.parent.is_none()); + entry.parent = Some(parent.rep()); + self.occupied_mut(parent.rep())?.add_child(child.rep()); + Ok(()) + } + + pub fn remove_child( + &mut self, + child: TableId, + parent: TableId, + ) -> Result<(), TableError> { + let entry = self.occupied_mut(child.rep())?; + assert_eq!(entry.parent, Some(parent.rep())); + entry.parent = None; + self.occupied_mut(parent.rep())?.remove_child(child.rep()); + Ok(()) + } + + /// Get an immutable reference to a task of a given type at a given index. + /// + /// Multiple shared references can be borrowed at any given time. + pub fn get(&self, key: TableId) -> Result<&T, TableError> { + self.get_(key.rep())? + .downcast_ref() + .ok_or(TableError::WrongType) + } + + fn get_(&self, key: u32) -> Result<&dyn Any, TableError> { + let r = self.occupied(key)?; + Ok(&*r.entry) + } + + /// Get an mutable reference to a task of a given type at a given index. + pub fn get_mut(&mut self, key: TableId) -> Result<&mut T, TableError> { + self.get_mut_(key.rep())? + .downcast_mut() + .ok_or(TableError::WrongType) + } + + pub fn get_mut_(&mut self, key: u32) -> Result<&mut dyn Any, TableError> { + let r = self.occupied_mut(key)?; + Ok(&mut *r.entry) + } + + /// Delete the specified task + pub fn delete(&mut self, key: TableId) -> Result { + self.delete_entry(key.rep())? + .entry + .downcast() + .map(|v| *v) + .map_err(|_| TableError::WrongType) + } + + pub fn delete_any(&mut self, key: u32) -> Result, TableError> { + Ok(self.delete_entry(key)?.entry) + } + + fn delete_entry(&mut self, key: u32) -> Result { + if !self.occupied(key)?.children.is_empty() { + return Err(TableError::HasChildren); + } + let e = self.free_entry(key as usize); + if let Some(parent) = e.parent { + // Remove deleted task from parent's child list. Parent must still + // be present because it cant be deleted while still having + // children: + self.occupied_mut(parent) + .expect("missing parent") + .remove_child(key); + } + Ok(e) + } +} diff --git a/crates/wasmtime/src/runtime/component/func.rs b/crates/wasmtime/src/runtime/component/func.rs index 65687002cb4a..8dcf4398feab 100644 --- a/crates/wasmtime/src/runtime/component/func.rs +++ b/crates/wasmtime/src/runtime/component/func.rs @@ -15,6 +15,9 @@ use wasmtime_environ::component::{ TypeFuncIndex, TypeTuple, MAX_FLAT_PARAMS, MAX_FLAT_RESULTS, }; +#[cfg(feature = "component-model-async")] +use crate::component::concurrent::{self, LiftLowerContext, Promise}; + mod host; mod options; mod typed; @@ -22,6 +25,13 @@ pub use self::host::*; pub use self::options::*; pub use self::typed::*; +#[cfg(feature = "component-model-async")] +type LowerFn = + fn(&mut LowerContext, &Params, InterfaceType, &mut MaybeUninit) -> Result<()>; + +#[cfg(feature = "component-model-async")] +type LiftFn = fn(&mut LiftContext, InterfaceType, &[ValRaw]) -> Result; + #[repr(C)] union ParamsAndResults { params: Params, @@ -36,17 +46,17 @@ union ParamsAndResults { /// [`wasmtime::Func`](crate::Func) it's possible to call functions either /// synchronously or asynchronously and either typed or untyped. #[derive(Copy, Clone, Debug)] -pub struct Func(Stored); +pub struct Func(pub(crate) Stored); #[doc(hidden)] pub struct FuncData { - export: ExportFunction, - ty: TypeFuncIndex, - types: Arc, - options: Options, - instance: Instance, - component_instance: RuntimeComponentInstanceIndex, - post_return: Option, + pub(crate) export: ExportFunction, + pub(crate) ty: TypeFuncIndex, + pub(crate) types: Arc, + pub(crate) options: Options, + pub(crate) instance: Instance, + pub(crate) component_instance: RuntimeComponentInstanceIndex, + pub(crate) post_return: Option, post_return_arg: Option, } @@ -72,7 +82,19 @@ impl Func { ExportFunction { func_ref } }); let component_instance = options.instance; - let options = unsafe { Options::new(store.id(), memory, realloc, options.string_encoding) }; + let callback = options + .callback + .map(|i| data.instance().runtime_callback(i)); + let options = unsafe { + Options::new( + store.id(), + memory, + realloc, + options.string_encoding, + options.async_, + callback, + ) + }; Func(store.store_data_mut().insert(FuncData { export, options, @@ -269,9 +291,9 @@ impl Func { /// Panics if this is called on a function in an asynchronous store. This /// only works with functions defined within a synchronous store. Also /// panics if `store` does not own this function. - pub fn call( + pub fn call( &self, - mut store: impl AsContextMut, + mut store: impl AsContextMut, params: &[Val], results: &mut [Val], ) -> Result<()> { @@ -294,32 +316,98 @@ impl Func { /// only works with functions defined within an asynchronous store. Also /// panics if `store` does not own this function. #[cfg(feature = "async")] - pub async fn call_async( + pub async fn call_async( &self, mut store: impl AsContextMut, params: &[Val], results: &mut [Val], - ) -> Result<()> - where - T: Send, - { - let mut store = store.as_context_mut(); + ) -> Result<()> { + let store = store.as_context_mut(); assert!( store.0.async_support(), "cannot use `call_async` without enabling async support in the config" ); - store - .on_fiber(|store| self.call_impl(store, params, results)) + #[cfg(feature = "component-model-async")] + { + let instance = store.0[self.0].component_instance; + // TODO: do we need to return the store here due to the possible + // invalidation of the reference we were passed? + concurrent::on_fiber(store, Some(instance), move |store| { + self.call_impl(store, params, results) + }) .await? + .0 + } + #[cfg(not(feature = "component-model-async"))] + { + let mut store = store; + store + .on_fiber(|store| self.call_impl(store, params, results)) + .await? + } } - fn call_impl( + /// Start concurrent call to this function. + /// + /// Unlike [`Self::call`] and [`Self::call_async`] (both of which require + /// exclusive access to the store until the completion of the call), calls + /// made using this method may run concurrently with other calls to the same + /// instance. + #[cfg(feature = "component-model-async")] + pub async fn call_concurrent( + self, + mut store: impl AsContextMut, + params: Vec, + ) -> Result>> { + let store = store.as_context_mut(); + assert!( + store.0.async_support(), + "cannot use `call_concurrent` when async support is not enabled on the config" + ); + let instance = store.0[self.0].component_instance; + // TODO: do we need to return the store here due to the possible + // invalidation of the reference we were passed? + concurrent::on_fiber(store, Some(instance), move |store| { + self.start_call(store.as_context_mut(), params) + }) + .await? + .0 + } + + #[cfg(feature = "component-model-async")] + fn start_call<'a, T: Send>( + self, + mut store: StoreContextMut<'a, T>, + params: Vec, + ) -> Result>> { + let store = store.as_context_mut(); + + let param_tys = self.params(&store); + if param_tys.len() != params.len() { + bail!( + "expected {} argument(s), got {}", + param_tys.len(), + params.len() + ); + } + + let lower = Self::lower_args as LowerFn<_, _, _>; + let lift = if store.0[self.0].options.async_() { + Self::lift_results_async as LiftFn<_> + } else { + Self::lift_results_sync as LiftFn<_> + }; + + Ok(self.start_call_raw_async(store, params, lower, lift)?.0) + } + + fn call_impl( &self, - mut store: impl AsContextMut, + mut store: impl AsContextMut, params: &[Val], results: &mut [Val], ) -> Result<()> { - let store = &mut store.as_context_mut(); + let store = store.as_context_mut(); let param_tys = self.params(&store); let result_tys = self.results(&store); @@ -333,49 +421,122 @@ impl Func { } if result_tys.len() != results.len() { bail!( - "expected {} results(s), got {}", + "expected {} result(s), got {}", result_tys.len(), results.len() ); } - self.call_raw( - store, - params, - |cx, params, params_ty, dst: &mut MaybeUninit<[ValRaw; MAX_FLAT_PARAMS]>| { - let params_ty = match params_ty { - InterfaceType::Tuple(i) => &cx.types[i], - _ => unreachable!(), - }; - if params_ty.abi.flat_count(MAX_FLAT_PARAMS).is_some() { - let dst = &mut unsafe { - mem::transmute::<_, &mut [MaybeUninit; MAX_FLAT_PARAMS]>(dst) - } - .iter_mut(); - - params - .iter() - .zip(params_ty.types.iter()) - .try_for_each(|(param, ty)| param.lower(cx, *ty, dst)) - } else { - self.store_args(cx, ¶ms_ty, params, dst) + if store.0[self.0].options.async_() { + #[cfg(feature = "component-model-async")] + { + for (result, slot) in self + .call_raw_async( + store, + params.iter().cloned().collect(), + Self::lower_args, + Self::lift_results_async, + )? + .0 + .into_iter() + .zip(results) + { + *slot = result; } - }, - |cx, results_ty, src: &[ValRaw; MAX_FLAT_RESULTS]| { - let results_ty = match results_ty { - InterfaceType::Tuple(i) => &cx.types[i], - _ => unreachable!(), - }; - if results_ty.abi.flat_count(MAX_FLAT_RESULTS).is_some() { - let mut flat = src.iter(); - for (ty, slot) in results_ty.types.iter().zip(results) { - *slot = Val::lift(cx, *ty, &mut flat)?; + Ok(()) + } + #[cfg(not(feature = "component-model-async"))] + { + unreachable!( + "async-lifted exports should have failed validation \ + when `component-model-async` feature disabled" + ); + } + } else { + self.call_raw( + store, + ¶ms.iter().cloned().collect::>(), + Self::lower_args, + |cx, results_ty, src: &[ValRaw; MAX_FLAT_RESULTS]| { + for (result, slot) in Self::lift_results_sync(cx, results_ty, src)? + .into_iter() + .zip(results) + { + *slot = result; } Ok(()) - } else { - Self::load_results(cx, results_ty, results, &mut src.iter()) - } + }, + ) + } + } + + #[cfg(feature = "component-model-async")] + fn call_raw_async<'a, T: Send, Params, Return: Send + Sync + 'static, LowerParams>( + &self, + store: StoreContextMut<'a, T>, + params: Params, + lower: LowerFn, + lift: LiftFn, + ) -> Result<(Return, StoreContextMut<'a, T>)> + where + LowerParams: Copy, + { + let me = self.0; + // Note that we smuggle the params through as raw pointers to avoid + // requiring `Params: Send + Sync + 'static` bounds on this function, + // which would prevent passing references as parameters. Technically, + // we don't need to do that for the return type, but we do it anyway for + // symmetry. + // + // This is only safe because `concurrent::call` will either consume or + // drop the contexts before returning. + concurrent::call::<_, LowerParams, _>( + store, + lower_params_with_context::>, + concurrent::LiftLowerContext { + pointer: Box::into_raw(Box::new((me, params, lower))) as _, + dropper: drop_context::<(Stored, Params, LowerFn)>, + }, + lift_results_with_context::>, + concurrent::LiftLowerContext { + pointer: Box::into_raw(Box::new((me, lift))) as _, + dropper: drop_context::<(Stored, LiftFn)>, + }, + *self, + ) + } + + #[cfg(feature = "component-model-async")] + fn start_call_raw_async< + 'a, + T: Send, + Params: Send + Sync + 'static, + Return: Send + Sync + 'static, + LowerParams, + >( + &self, + store: StoreContextMut<'a, T>, + params: Params, + lower: LowerFn, + lift: LiftFn, + ) -> Result<(Promise, StoreContextMut<'a, T>)> + where + LowerParams: Copy, + { + let me = self.0; + concurrent::start_call::<_, LowerParams, _>( + store, + lower_params_with_context::>, + concurrent::LiftLowerContext { + pointer: Box::into_raw(Box::new((me, params, lower))) as _, + dropper: drop_context::<(Stored, Params, LowerFn)>, + }, + lift_results_with_context::>, + concurrent::LiftLowerContext { + pointer: Box::into_raw(Box::new((me, lift))) as _, + dropper: drop_context::<(Stored, LiftFn)>, }, + *self, ) } @@ -389,7 +550,7 @@ impl Func { /// happening. fn call_raw( &self, - store: &mut StoreContextMut<'_, T>, + mut store: StoreContextMut<'_, T>, params: &Params, lower: impl FnOnce( &mut LowerContext<'_, T>, @@ -468,7 +629,7 @@ impl Func { // on the correctness of this module and `ComponentType` // implementations, hence `ComponentType` being an `unsafe` trait. crate::Func::call_unchecked_raw( - store, + &mut store, export.func_ref, NonNull::new(core::ptr::slice_from_raw_parts_mut( space.as_mut_ptr().cast(), @@ -644,8 +805,32 @@ impl Func { Ok(()) } + fn lower_args( + cx: &mut LowerContext<'_, T>, + params: &Vec, + params_ty: InterfaceType, + dst: &mut MaybeUninit<[ValRaw; MAX_FLAT_PARAMS]>, + ) -> Result<()> { + let params_ty = match params_ty { + InterfaceType::Tuple(i) => &cx.types[i], + _ => unreachable!(), + }; + if params_ty.abi.flat_count(MAX_FLAT_PARAMS).is_some() { + let dst = &mut unsafe { + mem::transmute::<_, &mut [MaybeUninit; MAX_FLAT_PARAMS]>(dst) + } + .iter_mut(); + + params + .iter() + .zip(params_ty.types.iter()) + .try_for_each(|(param, ty)| param.lower(cx, *ty, dst)) + } else { + Self::store_args(cx, ¶ms_ty, params, dst) + } + } + fn store_args( - &self, cx: &mut LowerContext<'_, T>, params_ty: &TypeTuple, args: &[Val], @@ -664,12 +849,55 @@ impl Func { Ok(()) } + fn lift_results_sync( + cx: &mut LiftContext<'_>, + results_ty: InterfaceType, + src: &[ValRaw], + ) -> Result> { + Self::lift_results(cx, results_ty, src, false) + } + + #[cfg(feature = "component-model-async")] + fn lift_results_async( + cx: &mut LiftContext<'_>, + results_ty: InterfaceType, + src: &[ValRaw], + ) -> Result> { + Self::lift_results(cx, results_ty, src, true) + } + + fn lift_results( + cx: &mut LiftContext<'_>, + results_ty: InterfaceType, + src: &[ValRaw], + async_: bool, + ) -> Result> { + let results_ty = match results_ty { + InterfaceType::Tuple(i) => &cx.types[i], + _ => unreachable!(), + }; + let limit = if async_ { + MAX_FLAT_PARAMS + } else { + MAX_FLAT_RESULTS + }; + if results_ty.abi.flat_count(limit).is_some() { + let mut flat = src.iter(); + results_ty + .types + .iter() + .map(|ty| Val::lift(cx, *ty, &mut flat)) + .collect() + } else { + Self::load_results(cx, results_ty, &mut src.iter()) + } + } + fn load_results( cx: &mut LiftContext<'_>, results_ty: &TypeTuple, - results: &mut [Val], src: &mut core::slice::Iter<'_, ValRaw>, - ) -> Result<()> { + ) -> Result> { // FIXME(#4311): needs to read an i64 for memory64 let ptr = usize::try_from(src.next().unwrap().get_u32())?; if ptr % usize::try_from(results_ty.abi.align32)? != 0 { @@ -683,11 +911,156 @@ impl Func { .ok_or_else(|| anyhow::anyhow!("pointer out of bounds of memory"))?; let mut offset = 0; - for (ty, slot) in results_ty.types.iter().zip(results) { - let abi = cx.types.canonical_abi(ty); - let offset = abi.next_field32_size(&mut offset); - *slot = Val::load(cx, *ty, &bytes[offset..][..abi.size32 as usize])?; + results_ty + .types + .iter() + .map(|ty| { + let abi = cx.types.canonical_abi(ty); + let offset = abi.next_field32_size(&mut offset); + Val::load(cx, *ty, &bytes[offset..][..abi.size32 as usize]) + }) + .collect() + } +} + +#[cfg(feature = "component-model-async")] +fn drop_context(pointer: *mut u8) { + drop(unsafe { Box::from_raw(pointer as *mut T) }) +} + +#[cfg(feature = "component-model-async")] +fn lower_params_with_context< + Params, + LowerParams, + T, + F: FnOnce( + &mut LowerContext, + &Params, + InterfaceType, + &mut MaybeUninit, + ) -> Result<()> + + Send + + Sync, +>( + context: LiftLowerContext, + store: *mut dyn crate::vm::VMStore, + lowered: &mut [MaybeUninit], +) -> Result<()> { + let (me, params, lower) = unsafe { + *Box::from_raw( + std::mem::ManuallyDrop::new(context).pointer as *mut (Stored, Params, F), + ) + }; + + lower_params(store, lowered, me, params, lower) +} + +#[cfg(feature = "component-model-async")] +fn lower_params< + Params, + LowerParams, + T, + F: FnOnce( + &mut LowerContext, + &Params, + InterfaceType, + &mut MaybeUninit, + ) -> Result<()> + + Send + + Sync, +>( + store: *mut dyn crate::vm::VMStore, + lowered: &mut [MaybeUninit], + me: Stored, + params: Params, + lower: F, +) -> Result<()> { + use crate::component::storage::slice_to_storage_mut; + + let mut store = unsafe { StoreContextMut(&mut *store.cast()) }; + let FuncData { + options, + instance, + component_instance, + ty, + .. + } = store.0[me]; + + let instance = store.0[instance.0].as_ref().unwrap(); + let types = instance.component_types().clone(); + let instance_ptr = instance.instance_ptr(); + let mut flags = instance.instance().instance_flags(component_instance); + + unsafe { + if !flags.may_enter() { + bail!(crate::Trap::CannotEnterComponent); + } + + flags.set_may_leave(false); + let mut cx = LowerContext::new(store.as_context_mut(), &options, &types, instance_ptr); + let result = lower( + &mut cx, + ¶ms, + InterfaceType::Tuple(types[ty].params), + slice_to_storage_mut(lowered), + ); + flags.set_may_leave(true); + result?; + + if !options.async_() { + flags.set_may_enter(false); + flags.set_needs_post_return(true); } + Ok(()) } } + +#[cfg(feature = "component-model-async")] +fn lift_results_with_context< + Return: Send + Sync + 'static, + T, + F: FnOnce(&mut LiftContext, InterfaceType, &[ValRaw]) -> Result + Send + Sync, +>( + context: LiftLowerContext, + store: *mut dyn crate::vm::VMStore, + lowered: &[ValRaw], +) -> Result> { + let (me, lift) = unsafe { + *Box::from_raw(std::mem::ManuallyDrop::new(context).pointer as *mut (Stored, F)) + }; + + lift_results::<_, T, _>(store, lowered, me, lift) +} + +#[cfg(feature = "component-model-async")] +fn lift_results< + Return: Send + Sync + 'static, + T, + F: FnOnce(&mut LiftContext, InterfaceType, &[ValRaw]) -> Result + Send + Sync, +>( + store: *mut dyn crate::vm::VMStore, + lowered: &[ValRaw], + me: Stored, + lift: F, +) -> Result> { + let store = unsafe { StoreContextMut::(&mut *store.cast()) }; + let FuncData { + options, + instance, + ty, + .. + } = store.0[me]; + + let instance = store.0[instance.0].as_ref().unwrap(); + let types = instance.component_types().clone(); + let instance_ptr = instance.instance_ptr(); + + unsafe { + Ok(Box::new(lift( + &mut LiftContext::new(store.0, &options, &types, instance_ptr), + InterfaceType::Tuple(types[ty].results), + lowered, + )?) as Box) + } +} diff --git a/crates/wasmtime/src/runtime/component/func/host.rs b/crates/wasmtime/src/runtime/component/func/host.rs index a8c593286bde..1c1130487be2 100644 --- a/crates/wasmtime/src/runtime/component/func/host.rs +++ b/crates/wasmtime/src/runtime/component/func/host.rs @@ -1,3 +1,4 @@ +use crate::component::concurrent; use crate::component::func::{LiftContext, LowerContext, Options}; use crate::component::matching::InstanceType; use crate::component::storage::slice_to_storage_mut; @@ -10,13 +11,28 @@ use crate::runtime::vm::{VMFuncRef, VMGlobalDefinition, VMMemoryDefinition, VMOp use crate::{AsContextMut, CallHook, StoreContextMut, ValRaw}; use alloc::sync::Arc; use core::any::Any; +use core::future::Future; +use core::iter; use core::mem::{self, MaybeUninit}; use core::ptr::NonNull; use wasmtime_environ::component::{ - CanonicalAbiInfo, ComponentTypes, InterfaceType, StringEncoding, TypeFuncIndex, - MAX_FLAT_PARAMS, MAX_FLAT_RESULTS, + CanonicalAbiInfo, ComponentTypes, InterfaceType, RuntimeComponentInstanceIndex, StringEncoding, + TypeFuncIndex, MAX_FLAT_PARAMS, MAX_FLAT_RESULTS, }; +#[cfg(feature = "component-model-async")] +use crate::runtime::vm::SendSyncPtr; + +#[cfg(feature = "component-model-async")] +const STATUS_PARAMS_READ: u32 = 1; +#[cfg(feature = "component-model-async")] +const STATUS_DONE: u32 = 3; + +struct Ptr(*const F); + +unsafe impl Sync for Ptr {} +unsafe impl Send for Ptr {} + pub struct HostFunc { entrypoint: VMLoweringCallee, typecheck: Box) -> Result<()>) + Send + Sync>, @@ -28,9 +44,23 @@ impl HostFunc { where F: Fn(StoreContextMut, P) -> Result + Send + Sync + 'static, P: ComponentNamedList + Lift + 'static, - R: ComponentNamedList + Lower + 'static, + R: ComponentNamedList + Lower + Send + Sync + 'static, { - let entrypoint = Self::entrypoint::; + Self::from_concurrent(move |store, params| { + let result = func(store, params); + async move { concurrent::for_any(move |_| result) } + }) + } + + pub(crate) fn from_concurrent(func: F) -> Arc + where + N: FnOnce(StoreContextMut) -> Result + Send + Sync + 'static, + FN: Future + Send + Sync + 'static, + F: Fn(StoreContextMut, P) -> FN + Send + Sync + 'static, + P: ComponentNamedList + Lift + 'static, + R: ComponentNamedList + Lower + Send + Sync + 'static, + { + let entrypoint = Self::entrypoint::; Arc::new(HostFunc { entrypoint, typecheck: Box::new(typecheck::), @@ -38,10 +68,11 @@ impl HostFunc { }) } - extern "C" fn entrypoint( + extern "C" fn entrypoint( cx: NonNull, data: NonNull, ty: u32, + caller_instance: u32, flags: NonNull, memory: *mut VMMemoryDefinition, realloc: *mut VMFuncRef, @@ -51,25 +82,28 @@ impl HostFunc { storage_len: usize, ) -> bool where - F: Fn(StoreContextMut, P) -> Result, + N: FnOnce(StoreContextMut) -> Result + Send + Sync + 'static, + FN: Future + Send + Sync + 'static, + F: Fn(StoreContextMut, P) -> FN + Send + Sync + 'static, P: ComponentNamedList + Lift + 'static, - R: ComponentNamedList + Lower + 'static, + R: ComponentNamedList + Lower + Send + Sync + 'static, { - let data = data.as_ptr() as *const F; + let data = Ptr(data.as_ptr() as *const F); unsafe { call_host_and_handle_result::(cx, |instance, types, store| { - call_host::<_, _, _, _>( + call_host( instance, types, store, TypeFuncIndex::from_u32(ty), + RuntimeComponentInstanceIndex::from_u32(caller_instance), InstanceFlags::from_raw(flags), memory, realloc, StringEncoding::from_u8(string_encoding).unwrap(), async_ != 0, NonNull::slice_from_raw_parts(storage, storage_len).as_mut(), - |store, args| (*data)(store, args), + move |store, args| (*data.0)(store, args), ) }) } @@ -78,14 +112,30 @@ impl HostFunc { pub(crate) fn new_dynamic(func: F) -> Arc where F: Fn(StoreContextMut<'_, T>, &[Val], &mut [Val]) -> Result<()> + Send + Sync + 'static, + { + Self::new_dynamic_concurrent(move |store, params: Vec, result_count| { + let mut results = iter::repeat(Val::Bool(false)) + .take(result_count) + .collect::>(); + let result = func(store, ¶ms, &mut results); + let result = result.map(move |()| results); + async move { concurrent::for_any(move |_| result) } + }) + } + + pub(crate) fn new_dynamic_concurrent(f: F) -> Arc + where + N: FnOnce(StoreContextMut) -> Result> + Send + Sync + 'static, + FN: Future + Send + Sync + 'static, + F: Fn(StoreContextMut, Vec, usize) -> FN + Send + Sync + 'static, { Arc::new(HostFunc { - entrypoint: dynamic_entrypoint::, + entrypoint: dynamic_entrypoint::, // This function performs dynamic type checks and subsequently does // not need to perform up-front type checks. Instead everything is // dynamically managed at runtime. typecheck: Box::new(move |_expected_index, _expected_types| Ok(())), - func: Box::new(func), + func: Box::new(f), }) } @@ -135,11 +185,12 @@ where /// This function is in general `unsafe` as the validity of all the parameters /// must be upheld. Generally that's done by ensuring this is only called from /// the select few places it's intended to be called from. -unsafe fn call_host( +unsafe fn call_host( instance: *mut ComponentInstance, types: &Arc, mut cx: StoreContextMut<'_, T>, ty: TypeFuncIndex, + caller_instance: RuntimeComponentInstanceIndex, mut flags: InstanceFlags, memory: *mut VMMemoryDefinition, realloc: *mut VMFuncRef, @@ -149,14 +200,12 @@ unsafe fn call_host( closure: F, ) -> Result<()> where + N: FnOnce(StoreContextMut) -> Result + Send + Sync + 'static, + FN: Future + Send + Sync + 'static, + F: Fn(StoreContextMut, Params) -> FN + 'static, Params: Lift, - Return: Lower, - F: FnOnce(StoreContextMut<'_, T>, Params) -> Result, + Return: Lower + Send + Sync + 'static, { - if async_ { - todo!() - } - /// Representation of arguments to this function when a return pointer is in /// use, namely the argument list is followed by a single value which is the /// return pointer. @@ -180,6 +229,8 @@ where NonNull::new(memory), NonNull::new(realloc), string_encoding, + async_, + None, ); // Perform a dynamic check that this instance can indeed be left. Exiting @@ -193,39 +244,85 @@ where let param_tys = InterfaceType::Tuple(ty.params); let result_tys = InterfaceType::Tuple(ty.results); - // There's a 2x2 matrix of whether parameters and results are stored on the - // stack or on the heap. Each of the 4 branches here have a different - // representation of the storage of arguments/returns. - // - // Also note that while four branches are listed here only one is taken for - // any particular `Params` and `Return` combination. This should be - // trivially DCE'd by LLVM. Perhaps one day with enough const programming in - // Rust we can make monomorphizations of this function codegen only one - // branch, but today is not that day. - let mut storage: Storage<'_, Params, Return> = if Params::flatten_count() <= MAX_FLAT_PARAMS { - if Return::flatten_count() <= MAX_FLAT_RESULTS { - Storage::Direct(slice_to_storage_mut(storage)) - } else { - Storage::ResultsIndirect(slice_to_storage_mut(storage).assume_init_ref()) + if async_ { + #[cfg(feature = "component-model-async")] + { + let paramptr = storage[0].assume_init(); + let retptr = storage[1].assume_init(); + + let params = { + let lift = &mut LiftContext::new(cx.0, &options, types, instance); + lift.enter_call(); + let ptr = validate_inbounds::(lift.memory(), ¶mptr)?; + Params::load(lift, param_tys, &lift.memory()[ptr..][..Params::SIZE32])? + }; + + let future = closure(cx.as_context_mut(), params); + + let task = + concurrent::first_poll(instance, cx.as_context_mut(), future, caller_instance, { + let types = types.clone(); + let instance = SendSyncPtr::new(NonNull::new(instance).unwrap()); + move |cx, ret: Return| { + let mut lower = LowerContext::new(cx, &options, &types, instance.as_ptr()); + let ptr = validate_inbounds::(lower.as_slice_mut(), &retptr)?; + ret.store(&mut lower, result_tys, ptr) + } + })?; + + let status = if let Some(task) = task { + (STATUS_PARAMS_READ << 30) | task + } else { + STATUS_DONE << 30 + }; + + storage[0] = MaybeUninit::new(ValRaw::i32(status as i32)); + } + #[cfg(not(feature = "component-model-async"))] + { + unreachable!( + "async-lowered imports should have failed validation \ + when `component-model-async` feature disabled" + ); } } else { - if Return::flatten_count() <= MAX_FLAT_RESULTS { - Storage::ParamsIndirect(slice_to_storage_mut(storage)) + // There's a 2x2 matrix of whether parameters and results are stored on the + // stack or on the heap. Each of the 4 branches here have a different + // representation of the storage of arguments/returns. + // + // Also note that while four branches are listed here only one is taken for + // any particular `Params` and `Return` combination. This should be + // trivially DCE'd by LLVM. Perhaps one day with enough const programming in + // Rust we can make monomorphizations of this function codegen only one + // branch, but today is not that day. + let mut storage: Storage<'_, Params, Return> = if Params::flatten_count() <= MAX_FLAT_PARAMS + { + if Return::flatten_count() <= MAX_FLAT_RESULTS { + Storage::Direct(slice_to_storage_mut(storage)) + } else { + Storage::ResultsIndirect(slice_to_storage_mut(storage).assume_init_ref()) + } } else { - Storage::Indirect(slice_to_storage_mut(storage).assume_init_ref()) - } - }; - let mut lift = LiftContext::new(cx.0, &options, types, instance); - lift.enter_call(); - let params = storage.lift_params(&mut lift, param_tys)?; + if Return::flatten_count() <= MAX_FLAT_RESULTS { + Storage::ParamsIndirect(slice_to_storage_mut(storage)) + } else { + Storage::Indirect(slice_to_storage_mut(storage).assume_init_ref()) + } + }; + let mut lift = LiftContext::new(cx.0, &options, types, instance); + lift.enter_call(); + let params = storage.lift_params(&mut lift, param_tys)?; + + let future = closure(cx.as_context_mut(), params); - let ret = closure(cx.as_context_mut(), params)?; - flags.set_may_leave(false); - let mut lower = LowerContext::new(cx, &options, types, instance); - storage.lower_results(&mut lower, result_tys, ret)?; - flags.set_may_leave(true); + let (ret, cx) = concurrent::poll_and_block(cx, future, caller_instance)?; - lower.exit_call()?; + flags.set_may_leave(false); + let mut lower = LowerContext::new(cx, &options, types, instance); + storage.lower_results(&mut lower, result_tys, ret)?; + flags.set_may_leave(true); + lower.exit_call()?; + } return Ok(()); @@ -280,7 +377,7 @@ where } } -fn validate_inbounds(memory: &[u8], ptr: &ValRaw) -> Result { +pub(crate) fn validate_inbounds(memory: &[u8], ptr: &ValRaw) -> Result { // FIXME(#4311): needs memory64 support let ptr = usize::try_from(ptr.get_u32())?; if ptr % usize::try_from(T::ALIGN32)? != 0 { @@ -318,11 +415,12 @@ unsafe fn call_host_and_handle_result( }) } -unsafe fn call_host_dynamic( +unsafe fn call_host_dynamic( instance: *mut ComponentInstance, types: &Arc, mut store: StoreContextMut<'_, T>, ty: TypeFuncIndex, + caller_instance: RuntimeComponentInstanceIndex, mut flags: InstanceFlags, memory: *mut VMMemoryDefinition, realloc: *mut VMFuncRef, @@ -332,17 +430,17 @@ unsafe fn call_host_dynamic( closure: F, ) -> Result<()> where - F: FnOnce(StoreContextMut<'_, T>, &[Val], &mut [Val]) -> Result<()>, + N: FnOnce(StoreContextMut) -> Result> + Send + Sync + 'static, + FN: Future + Send + Sync + 'static, + F: Fn(StoreContextMut, Vec, usize) -> FN + 'static, { - if async_ { - todo!() - } - let options = Options::new( store.0.id(), NonNull::new(memory), NonNull::new(realloc), string_encoding, + async_, + None, ); // Perform a dynamic check that this instance can indeed be left. Exiting @@ -358,66 +456,145 @@ where let func_ty = &types[ty]; let param_tys = &types[func_ty.params]; let result_tys = &types[func_ty.results]; - let mut cx = LiftContext::new(store.0, &options, types, instance); - cx.enter_call(); - if let Some(param_count) = param_tys.abi.flat_count(MAX_FLAT_PARAMS) { - // NB: can use `MaybeUninit::slice_assume_init_ref` when that's stable - let mut iter = - mem::transmute::<&[MaybeUninit], &[ValRaw]>(&storage[..param_count]).iter(); - args = param_tys - .types - .iter() - .map(|ty| Val::lift(&mut cx, *ty, &mut iter)) - .collect::>>()?; - ret_index = param_count; - assert!(iter.next().is_none()); - } else { - let mut offset = - validate_inbounds_dynamic(¶m_tys.abi, cx.memory(), storage[0].assume_init_ref())?; - args = param_tys - .types - .iter() - .map(|ty| { - let abi = types.canonical_abi(ty); - let size = usize::try_from(abi.size32).unwrap(); - let memory = &cx.memory()[abi.next_field32_size(&mut offset)..][..size]; - Val::load(&mut cx, *ty, memory) - }) - .collect::>>()?; - ret_index = 1; - }; - let mut result_vals = Vec::with_capacity(result_tys.types.len()); - for _ in result_tys.types.iter() { - result_vals.push(Val::Bool(false)); - } - closure(store.as_context_mut(), &args, &mut result_vals)?; - flags.set_may_leave(false); - - let mut cx = LowerContext::new(store, &options, types, instance); - if let Some(cnt) = result_tys.abi.flat_count(MAX_FLAT_RESULTS) { - let mut dst = storage[..cnt].iter_mut(); - for (val, ty) in result_vals.iter().zip(result_tys.types.iter()) { - val.lower(&mut cx, *ty, &mut dst)?; + if async_ { + #[cfg(feature = "component-model-async")] + { + let paramptr = storage[0].assume_init(); + let retptr = storage[1].assume_init(); + + let params = { + let mut lift = &mut LiftContext::new(store.0, &options, types, instance); + lift.enter_call(); + let mut offset = + validate_inbounds_dynamic(¶m_tys.abi, lift.memory(), ¶mptr)?; + param_tys + .types + .iter() + .map(|ty| { + let abi = types.canonical_abi(ty); + let size = usize::try_from(abi.size32).unwrap(); + let memory = &lift.memory()[abi.next_field32_size(&mut offset)..][..size]; + Val::load(&mut lift, *ty, memory) + }) + .collect::>>()? + }; + + let future = closure(store.as_context_mut(), params, result_tys.types.len()); + + let task = concurrent::first_poll( + instance, + store.as_context_mut(), + future, + caller_instance, + { + let types = types.clone(); + let instance = SendSyncPtr::new(NonNull::new(instance).unwrap()); + let result_tys = func_ty.results; + move |store, result_vals: Vec| { + let result_tys = &types[result_tys]; + if result_vals.len() != result_tys.types.len() { + bail!("result length mismatch"); + } + + let mut lower = + LowerContext::new(store, &options, &types, instance.as_ptr()); + let mut ptr = validate_inbounds_dynamic( + &result_tys.abi, + lower.as_slice_mut(), + &retptr, + )?; + for (val, ty) in result_vals.iter().zip(result_tys.types.iter()) { + let offset = types.canonical_abi(ty).next_field32_size(&mut ptr); + val.store(&mut lower, *ty, offset)?; + } + Ok(()) + } + }, + )?; + + let status = if let Some(task) = task { + (STATUS_PARAMS_READ << 30) | task + } else { + STATUS_DONE << 30 + }; + + storage[0] = MaybeUninit::new(ValRaw::i32(status as i32)); + } + #[cfg(not(feature = "component-model-async"))] + { + unreachable!( + "async-lowered imports should have failed validation \ + when `component-model-async` feature disabled" + ); } - assert!(dst.next().is_none()); } else { - let ret_ptr = storage[ret_index].assume_init_ref(); - let mut ptr = validate_inbounds_dynamic(&result_tys.abi, cx.as_slice_mut(), ret_ptr)?; - for (val, ty) in result_vals.iter().zip(result_tys.types.iter()) { - let offset = types.canonical_abi(ty).next_field32_size(&mut ptr); - val.store(&mut cx, *ty, offset)?; + let mut cx = LiftContext::new(store.0, &options, types, instance); + cx.enter_call(); + if let Some(param_count) = param_tys.abi.flat_count(MAX_FLAT_PARAMS) { + // NB: can use `MaybeUninit::slice_assume_init_ref` when that's stable + let mut iter = + mem::transmute::<&[MaybeUninit], &[ValRaw]>(&storage[..param_count]).iter(); + args = param_tys + .types + .iter() + .map(|ty| Val::lift(&mut cx, *ty, &mut iter)) + .collect::>>()?; + ret_index = param_count; + assert!(iter.next().is_none()); + } else { + let mut offset = validate_inbounds_dynamic( + ¶m_tys.abi, + cx.memory(), + storage[0].assume_init_ref(), + )?; + args = param_tys + .types + .iter() + .map(|ty| { + let abi = types.canonical_abi(ty); + let size = usize::try_from(abi.size32).unwrap(); + let memory = &cx.memory()[abi.next_field32_size(&mut offset)..][..size]; + Val::load(&mut cx, *ty, memory) + }) + .collect::>>()?; + ret_index = 1; + }; + + let future = closure(store.as_context_mut(), args, result_tys.types.len()); + let (result_vals, store) = concurrent::poll_and_block(store, future, caller_instance)?; + + flags.set_may_leave(false); + + let mut cx = LowerContext::new(store, &options, types, instance); + if let Some(cnt) = result_tys.abi.flat_count(MAX_FLAT_RESULTS) { + let mut dst = storage[..cnt].iter_mut(); + for (val, ty) in result_vals.iter().zip(result_tys.types.iter()) { + val.lower(&mut cx, *ty, &mut dst)?; + } + assert!(dst.next().is_none()); + } else { + let ret_ptr = storage[ret_index].assume_init_ref(); + let mut ptr = validate_inbounds_dynamic(&result_tys.abi, cx.as_slice_mut(), ret_ptr)?; + for (val, ty) in result_vals.iter().zip(result_tys.types.iter()) { + let offset = types.canonical_abi(ty).next_field32_size(&mut ptr); + val.store(&mut cx, *ty, offset)?; + } } - } - flags.set_may_leave(true); + flags.set_may_leave(true); - cx.exit_call()?; + cx.exit_call()?; + } return Ok(()); } -fn validate_inbounds_dynamic(abi: &CanonicalAbiInfo, memory: &[u8], ptr: &ValRaw) -> Result { +pub(crate) fn validate_inbounds_dynamic( + abi: &CanonicalAbiInfo, + memory: &[u8], + ptr: &ValRaw, +) -> Result { // FIXME(#4311): needs memory64 support let ptr = usize::try_from(ptr.get_u32())?; if ptr % usize::try_from(abi.align32)? != 0 { @@ -433,10 +610,11 @@ fn validate_inbounds_dynamic(abi: &CanonicalAbiInfo, memory: &[u8], ptr: &ValRaw Ok(ptr) } -extern "C" fn dynamic_entrypoint( +extern "C" fn dynamic_entrypoint( cx: NonNull, data: NonNull, ty: u32, + caller_instance: u32, flags: NonNull, memory: *mut VMMemoryDefinition, realloc: *mut VMFuncRef, @@ -446,23 +624,26 @@ extern "C" fn dynamic_entrypoint( storage_len: usize, ) -> bool where - F: Fn(StoreContextMut<'_, T>, &[Val], &mut [Val]) -> Result<()> + Send + Sync + 'static, + N: FnOnce(StoreContextMut) -> Result> + Send + Sync + 'static, + FN: Future + Send + Sync + 'static, + F: Fn(StoreContextMut, Vec, usize) -> FN + Send + Sync + 'static, { - let data = data.as_ptr() as *const F; + let data = Ptr(data.as_ptr() as *const F); unsafe { call_host_and_handle_result(cx, |instance, types, store| { - call_host_dynamic::( + call_host_dynamic( instance, types, store, TypeFuncIndex::from_u32(ty), + RuntimeComponentInstanceIndex::from_u32(caller_instance), InstanceFlags::from_raw(flags), memory, realloc, StringEncoding::from_u8(string_encoding).unwrap(), async_ != 0, NonNull::slice_from_raw_parts(storage, storage_len).as_mut(), - |store, params, results| (*data)(store, params, results), + move |store, params, results| (*data.0)(store, params, results), ) }) } diff --git a/crates/wasmtime/src/runtime/component/func/options.rs b/crates/wasmtime/src/runtime/component/func/options.rs index 20bfa88709f5..78b2d77b1023 100644 --- a/crates/wasmtime/src/runtime/component/func/options.rs +++ b/crates/wasmtime/src/runtime/component/func/options.rs @@ -43,6 +43,11 @@ pub struct Options { /// /// This defaults to utf-8 but can be changed if necessary. string_encoding: StringEncoding, + + async_: bool, + + #[cfg_attr(not(feature = "component-model-async"), allow(unused))] + pub(crate) callback: Option>, } // The `Options` structure stores raw pointers but they're never used unless a @@ -66,12 +71,16 @@ impl Options { memory: Option>, realloc: Option>, string_encoding: StringEncoding, + async_: bool, + callback: Option>, ) -> Options { Options { store_id, memory, realloc, string_encoding, + async_, + callback, } } @@ -163,6 +172,11 @@ impl Options { pub fn store_id(&self) -> StoreId { self.store_id } + + /// Returns whether this lifting or lowering uses the async ABI. + pub fn async_(&self) -> bool { + self.async_ + } } /// A helper structure which is a "package" of the context used during lowering @@ -196,7 +210,7 @@ pub struct LowerContext<'a, T> { /// into. /// /// This pointer is required to be owned by the `store` provided. - instance: *mut ComponentInstance, + pub(crate) instance: *mut ComponentInstance, } #[doc(hidden)] @@ -402,7 +416,7 @@ pub struct LiftContext<'a> { memory: Option<&'a [u8]>, - instance: *mut ComponentInstance, + pub(crate) instance: *mut ComponentInstance, host_table: &'a mut ResourceTable, host_resource_data: &'a mut HostResourceData, diff --git a/crates/wasmtime/src/runtime/component/func/typed.rs b/crates/wasmtime/src/runtime/component/func/typed.rs index 534de7c821b3..acf7e8a29bc3 100644 --- a/crates/wasmtime/src/runtime/component/func/typed.rs +++ b/crates/wasmtime/src/runtime/component/func/typed.rs @@ -18,7 +18,7 @@ use wasmtime_environ::component::{ }; #[cfg(feature = "component-model-async")] -use crate::component::concurrent::Promise; +use crate::component::concurrent::{self, Promise}; /// A statically-typed version of [`Func`] which takes `Params` as input and /// returns `Return`. @@ -157,7 +157,14 @@ where /// Panics if this is called on a function in an asynchronous store. This /// only works with functions defined within a synchronous store. Also /// panics if `store` does not own this function. - pub fn call(&self, store: impl AsContextMut, params: Params) -> Result { + pub fn call( + &self, + store: impl AsContextMut, + params: Params, + ) -> Result + where + Return: Send + Sync + 'static, + { assert!( !store.as_context().async_support(), "must use `call_async` when async support is enabled on the config" @@ -173,24 +180,38 @@ where /// only works with functions defined within an asynchronous store. Also /// panics if `store` does not own this function. #[cfg(feature = "async")] - pub async fn call_async( - &self, + pub async fn call_async( + self, mut store: impl AsContextMut, params: Params, ) -> Result where - T: Send, Params: Send + Sync, - Return: Send + Sync, + Return: Send + Sync + 'static, { - let mut store = store.as_context_mut(); + let store = store.as_context_mut(); assert!( store.0.async_support(), "cannot use `call_async` when async support is not enabled on the config" ); - store - .on_fiber(|store| self.call_impl(store, params)) + #[cfg(feature = "component-model-async")] + { + let instance = store.0[self.func.0].component_instance; + // TODO: do we need to return the store here due to the possible + // invalidation of the reference we were passed? + concurrent::on_fiber(store, Some(instance), move |store| { + self.call_impl(store, params) + }) .await? + .0 + } + #[cfg(not(feature = "component-model-async"))] + { + let mut store = store; + store + .on_fiber(|store| self.call_impl(store, params)) + .await? + } } /// Start concurrent call to this function. @@ -214,12 +235,162 @@ where store.0.async_support(), "cannot use `call_concurrent` when async support is not enabled on the config" ); - _ = params; - todo!() + let instance = store.0[self.func.0].component_instance; + // TODO: do we need to return the store here due to the possible + // invalidation of the reference we were passed? + concurrent::on_fiber(store, Some(instance), move |store| { + self.start_call(store.as_context_mut(), params) + }) + .await? + .0 + } + + #[cfg(feature = "component-model-async")] + fn start_call<'a, T: Send>( + self, + store: StoreContextMut<'a, T>, + params: Params, + ) -> Result> + where + Params: Send + Sync + 'static, + Return: Send + Sync + 'static, + { + Ok(if store.0[self.func.0].options.async_() { + #[cfg(feature = "component-model-async")] + { + if Params::flatten_count() <= MAX_FLAT_PARAMS { + if Return::flatten_count() <= MAX_FLAT_PARAMS { + self.func.start_call_raw_async( + store, + params, + Self::lower_stack_args, + Self::lift_stack_result_raw, + ) + } else { + self.func.start_call_raw_async( + store, + params, + Self::lower_stack_args, + Self::lift_heap_result_raw, + ) + } + } else { + if Return::flatten_count() <= MAX_FLAT_PARAMS { + self.func.start_call_raw_async( + store, + params, + Self::lower_heap_args, + Self::lift_stack_result_raw, + ) + } else { + self.func.start_call_raw_async( + store, + params, + Self::lower_heap_args, + Self::lift_heap_result_raw, + ) + } + } + } + #[cfg(not(feature = "component-model-async"))] + { + unreachable!( + "async-lifted exports should have failed validation \ + when `component-model-async` feature disabled" + ); + } + } else if Params::flatten_count() <= MAX_FLAT_PARAMS { + if Return::flatten_count() <= MAX_FLAT_RESULTS { + self.func.start_call_raw_async( + store, + params, + Self::lower_stack_args, + Self::lift_stack_result_raw, + ) + } else { + self.func.start_call_raw_async( + store, + params, + Self::lower_stack_args, + Self::lift_heap_result_raw, + ) + } + } else { + if Return::flatten_count() <= MAX_FLAT_RESULTS { + self.func.start_call_raw_async( + store, + params, + Self::lower_heap_args, + Self::lift_stack_result_raw, + ) + } else { + self.func.start_call_raw_async( + store, + params, + Self::lower_heap_args, + Self::lift_heap_result_raw, + ) + } + }? + .0) } - fn call_impl(&self, mut store: impl AsContextMut, params: Params) -> Result { - let store = &mut store.as_context_mut(); + fn call_impl( + &self, + mut store: impl AsContextMut, + params: Params, + ) -> Result + where + Return: Send + Sync + 'static, + { + let store = store.as_context_mut(); + + if store.0[self.func.0].options.async_() { + #[cfg(feature = "component-model-async")] + { + return Ok(if Params::flatten_count() <= MAX_FLAT_PARAMS { + if Return::flatten_count() <= MAX_FLAT_PARAMS { + self.func.call_raw_async( + store, + params, + Self::lower_stack_args, + Self::lift_stack_result_raw, + ) + } else { + self.func.call_raw_async( + store, + params, + Self::lower_stack_args, + Self::lift_heap_result_raw, + ) + } + } else { + if Return::flatten_count() <= MAX_FLAT_PARAMS { + self.func.call_raw_async( + store, + params, + Self::lower_heap_args, + Self::lift_stack_result_raw, + ) + } else { + self.func.call_raw_async( + store, + params, + Self::lower_heap_args, + Self::lift_heap_result_raw, + ) + } + }? + .0); + } + #[cfg(not(feature = "component-model-async"))] + { + bail!( + "must enable the `component-model-async` feature to call async-lifted exports" + ) + } + } + // Note that this is in theory simpler than it might read at this time. // Here we're doing a runtime dispatch on the `flatten_count` for the // params/results to see whether they're inbounds. This creates 4 cases @@ -294,8 +465,6 @@ where ty: InterfaceType, dst: &mut MaybeUninit, ) -> Result<()> { - assert!(Params::flatten_count() > MAX_FLAT_PARAMS); - // Memory must exist via validation if the arguments are stored on the // heap, so we can create a `MemoryMut` at this point. Afterwards // `realloc` is used to allocate space for all the arguments and then @@ -330,10 +499,20 @@ where ty: InterfaceType, dst: &Return::Lower, ) -> Result { - assert!(Return::flatten_count() <= MAX_FLAT_RESULTS); Return::lift(cx, ty, dst) } + #[cfg(feature = "component-model-async")] + fn lift_stack_result_raw( + cx: &mut LiftContext<'_>, + ty: InterfaceType, + dst: &[ValRaw], + ) -> Result { + Self::lift_stack_result(cx, ty, unsafe { + crate::component::storage::slice_to_storage(dst) + }) + } + /// Lift the result of a function where the result is stored indirectly on /// the heap. fn lift_heap_result( @@ -356,6 +535,15 @@ where Return::load(cx, ty, bytes) } + #[cfg(feature = "component-model-async")] + fn lift_heap_result_raw( + cx: &mut LiftContext<'_>, + ty: InterfaceType, + dst: &[ValRaw], + ) -> Result { + Self::lift_heap_result(cx, ty, &dst[0]) + } + /// See [`Func::post_return`] pub fn post_return(&self, store: impl AsContextMut) -> Result<()> { self.func.post_return(store) @@ -1532,7 +1720,7 @@ pub struct WasmList { } impl WasmList { - fn new( + pub(crate) fn new( ptr: usize, len: usize, cx: &mut LiftContext<'_>, diff --git a/crates/wasmtime/src/runtime/component/instance.rs b/crates/wasmtime/src/runtime/component/instance.rs index 611822aa6e87..cbff7714cd66 100644 --- a/crates/wasmtime/src/runtime/component/instance.rs +++ b/crates/wasmtime/src/runtime/component/instance.rs @@ -1,3 +1,4 @@ +use crate::component::concurrent; use crate::component::func::HostFunc; use crate::component::matching::InstanceType; use crate::component::{ @@ -48,7 +49,7 @@ pub(crate) struct InstanceData { // of the component can be thrown away (theoretically). component: Component, - state: OwnedComponentInstance, + pub(crate) state: OwnedComponentInstance, /// Arguments that this instance used to be instantiated. /// @@ -830,12 +831,24 @@ impl InstancePre { where T: Send, { - let mut store = store.as_context_mut(); + let store = store.as_context_mut(); assert!( store.0.async_support(), "must use sync instantiation when async support is disabled" ); - store.on_fiber(|store| self.instantiate_impl(store)).await? + #[cfg(feature = "component-model-async")] + { + // TODO: do we need to return the store here due to the possible + // invalidation of the reference we were passed? + concurrent::on_fiber(store, None, move |store| self.instantiate_impl(store)) + .await? + .0 + } + #[cfg(not(feature = "component-model-async"))] + { + let mut store = store; + store.on_fiber(|store| self.instantiate_impl(store)).await? + } } fn instantiate_impl(&self, mut store: impl AsContextMut) -> Result { diff --git a/crates/wasmtime/src/runtime/component/linker.rs b/crates/wasmtime/src/runtime/component/linker.rs index 1a0cd4da4ae7..bd880104e0ee 100644 --- a/crates/wasmtime/src/runtime/component/linker.rs +++ b/crates/wasmtime/src/runtime/component/linker.rs @@ -382,7 +382,7 @@ impl LinkerInstance<'_, T> { } } - /// Defines a new host-provided function into this [`Linker`]. + /// Defines a new host-provided function into this [`LinkerInstance`]. /// /// This method is used to give host functions to wasm components. The /// `func` provided will be callable from linked components with the type @@ -404,13 +404,13 @@ impl LinkerInstance<'_, T> { where F: Fn(StoreContextMut, Params) -> Result + Send + Sync + 'static, Params: ComponentNamedList + Lift + 'static, - Return: ComponentNamedList + Lower + 'static, + Return: ComponentNamedList + Lower + Send + Sync + 'static, { self.insert(name, Definition::Func(HostFunc::from_closure(func)))?; Ok(()) } - /// Defines a new host-provided async function into this [`Linker`]. + /// Defines a new host-provided async function into this [`LinkerInstance`]. /// /// This is exactly like [`Self::func_wrap`] except it takes an async /// host function. @@ -425,16 +425,26 @@ impl LinkerInstance<'_, T> { + Sync + 'static, Params: ComponentNamedList + Lift + 'static, - Return: ComponentNamedList + Lower + 'static, + Return: ComponentNamedList + Lower + Send + Sync + 'static, { assert!( self.engine.config().async_support, "cannot use `func_wrap_async` without enabling async support in the config" ); + let ff = move |mut store: StoreContextMut<'_, T>, params: Params| -> Result { - let async_cx = store.as_context_mut().0.async_cx().expect("async cx"); - let mut future = Pin::from(f(store.as_context_mut(), params)); - unsafe { async_cx.block_on(future.as_mut()) }? + #[cfg(feature = "component-model-async")] + { + let async_cx = crate::component::concurrent::AsyncCx::new(&mut store); + let mut future = Pin::from(f(store.as_context_mut(), params)); + unsafe { async_cx.block_on::(future.as_mut(), None) }?.0 + } + #[cfg(not(feature = "component-model-async"))] + { + let async_cx = store.as_context_mut().0.async_cx().expect("async cx"); + let mut future = Pin::from(f(store.as_context_mut(), params)); + unsafe { async_cx.block_on(future.as_mut()) }? + } }; self.func_wrap(name, ff) } @@ -470,8 +480,8 @@ impl LinkerInstance<'_, T> { self.engine.config().async_support, "cannot use `func_wrap_concurrent` without enabling async support in the config" ); - _ = (name, f); - todo!() + self.insert(name, Definition::Func(HostFunc::from_concurrent(f)))?; + Ok(()) } /// Define a new host-provided function using dynamically typed values. @@ -603,13 +613,60 @@ impl LinkerInstance<'_, T> { "cannot use `func_new_async` without enabling async support in the config" ); let ff = move |mut store: StoreContextMut<'_, T>, params: &[Val], results: &mut [Val]| { - let async_cx = store.as_context_mut().0.async_cx().expect("async cx"); - let mut future = Pin::from(f(store.as_context_mut(), params, results)); - unsafe { async_cx.block_on(future.as_mut()) }? + #[cfg(feature = "component-model-async")] + { + let async_cx = crate::component::concurrent::AsyncCx::new(&mut store); + let mut future = Pin::from(f(store.as_context_mut(), params, results)); + unsafe { async_cx.block_on::(future.as_mut(), None) }?.0 + } + #[cfg(not(feature = "component-model-async"))] + { + let async_cx = store.as_context_mut().0.async_cx().expect("async cx"); + let mut future = Pin::from(f(store.as_context_mut(), params, results)); + unsafe { async_cx.block_on(future.as_mut()) }? + } }; self.func_new(name, ff) } + /// Define a new host-provided async function using dynamic types. + /// + /// This allows the caller to register host functions with the + /// `LinkerInstance` such that multiple calls to such functions can run + /// concurrently. This isn't possible with the existing func_wrap_async + /// method because it takes a function which returns a future that owns a + /// unique reference to the Store, meaning the Store can't be used for + /// anything else until the future resolves. + /// + /// Ideally, we'd have a way to thread a `StoreContextMut` through an + /// arbitrary `Future` such that it has access to the `Store` only while + /// being polled (i.e. between, but not across, await points). However, + /// there's currently no way to express that in async Rust, so we make do + /// with a more awkward scheme: each function registered using + /// `func_wrap_concurrent` gets access to the `Store` twice: once before + /// doing any concurrent operations (i.e. before awaiting) and once + /// afterward. This allows multiple calls to proceed concurrently without + /// any one of them monopolizing the store. + #[cfg(feature = "component-model-async")] + pub fn func_new_concurrent(&mut self, name: &str, f: F) -> Result<()> + where + N: FnOnce(StoreContextMut) -> Result> + Send + Sync + 'static, + FN: Future + Send + Sync + 'static, + F: Fn(StoreContextMut, Vec) -> FN + Send + Sync + 'static, + { + assert!( + self.engine.config().async_support, + "cannot use `func_wrap_concurrent` without enabling async support in the config" + ); + self.insert( + name, + Definition::Func(HostFunc::new_dynamic_concurrent(move |store, params, _| { + f(store, params) + })), + )?; + Ok(()) + } + /// Defines a [`Module`] within this instance. /// /// This can be used to provide a core wasm [`Module`] as an import to a @@ -675,11 +732,21 @@ impl LinkerInstance<'_, T> { let dtor = Arc::new(crate::func::HostFunc::wrap_inner( &self.engine, move |mut cx: crate::Caller<'_, T>, (param,): (u32,)| { - let async_cx = cx.as_context_mut().0.async_cx().expect("async cx"); - let mut future = Pin::from(dtor(cx.as_context_mut(), param)); - match unsafe { async_cx.block_on(future.as_mut()) } { - Ok(Ok(())) => Ok(()), - Ok(Err(trap)) | Err(trap) => Err(trap), + #[cfg(feature = "component-model-async")] + { + let async_cx = + crate::component::concurrent::AsyncCx::new(&mut cx.as_context_mut()); + let mut future = Pin::from(dtor(cx.as_context_mut(), param)); + unsafe { async_cx.block_on(future.as_mut(), None::>) }?.0 + } + #[cfg(not(feature = "component-model-async"))] + { + let async_cx = cx.as_context_mut().0.async_cx().expect("async cx"); + let mut future = Pin::from(dtor(cx.as_context_mut(), param)); + match unsafe { async_cx.block_on(future.as_mut()) } { + Ok(Ok(())) => Ok(()), + Ok(Err(trap)) | Err(trap) => Err(trap), + } } }, )); diff --git a/crates/wasmtime/src/runtime/component/matching.rs b/crates/wasmtime/src/runtime/component/matching.rs index 4222daa6dc62..d6cac001cb9a 100644 --- a/crates/wasmtime/src/runtime/component/matching.rs +++ b/crates/wasmtime/src/runtime/component/matching.rs @@ -1,5 +1,6 @@ use crate::component::func::HostFunc; use crate::component::linker::{Definition, Strings}; +use crate::component::types::{FutureType, StreamType}; use crate::component::ResourceType; use crate::prelude::*; use crate::runtime::vm::component::ComponentInstance; @@ -9,7 +10,7 @@ use alloc::sync::Arc; use core::any::Any; use wasmtime_environ::component::{ ComponentTypes, NameMap, ResourceIndex, TypeComponentInstance, TypeDef, TypeFuncIndex, - TypeModule, TypeResourceTableIndex, + TypeFutureTableIndex, TypeModule, TypeResourceTableIndex, TypeStreamTableIndex, }; use wasmtime_environ::PrimaryMap; @@ -199,6 +200,14 @@ impl<'a> InstanceType<'a> { .copied() .unwrap_or_else(|| ResourceType::uninstantiated(&self.types, index)) } + + pub fn future_type(&self, index: TypeFutureTableIndex) -> FutureType { + FutureType::from(self.types[index].ty, self) + } + + pub fn stream_type(&self, index: TypeStreamTableIndex) -> StreamType { + StreamType::from(self.types[index].ty, self) + } } /// Small helper method to downcast an `Arc` borrow into a borrow of a concrete diff --git a/crates/wasmtime/src/runtime/component/mod.rs b/crates/wasmtime/src/runtime/component/mod.rs index 49199d1a0c21..053a628af8ee 100644 --- a/crates/wasmtime/src/runtime/component/mod.rs +++ b/crates/wasmtime/src/runtime/component/mod.rs @@ -116,7 +116,8 @@ mod values; pub use self::component::{Component, ComponentExportIndex}; #[cfg(feature = "component-model-async")] pub use self::concurrent::{ - ErrorContext, FutureReader, Promise, PromisesUnordered, StreamReader, VMComponentAsyncStore, + for_any, future, stream, ErrorContext, FutureReader, FutureWriter, Promise, PromisesUnordered, + StreamReader, StreamWriter, VMComponentAsyncStore, }; pub use self::func::{ ComponentNamedList, ComponentType, Func, Lift, Lower, TypedFunc, WasmList, WasmStr, @@ -674,3 +675,478 @@ pub mod bindgen_examples; #[cfg(not(any(docsrs, test, doctest)))] #[doc(hidden)] pub mod bindgen_examples {} + +#[cfg(not(feature = "component-model-async"))] +pub(crate) mod concurrent { + use { + crate::{ + component::{ + func::{ComponentType, LiftContext, LowerContext}, + Val, + }, + vm::{VMFuncRef, VMMemoryDefinition, VMOpaqueContext}, + AsContextMut, StoreContextMut, ValRaw, + }, + alloc::{sync::Arc, task::Wake}, + anyhow::Result, + core::{ + future::Future, + marker::PhantomData, + mem::MaybeUninit, + pin::pin, + task::{Context, Poll, Waker}, + }, + wasmtime_environ::component::{ + InterfaceType, RuntimeComponentInstanceIndex, TypeComponentLocalErrorContextTableIndex, + TypeFutureTableIndex, TypeStreamTableIndex, TypeTaskReturnIndex, + }, + }; + + pub fn for_any(fun: F) -> F + where + F: FnOnce(StoreContextMut) -> R + 'static, + R: 'static, + { + fun + } + + fn dummy_waker() -> Waker { + struct DummyWaker; + + impl Wake for DummyWaker { + fn wake(self: Arc) {} + } + + Arc::new(DummyWaker).into() + } + + pub(crate) fn poll_and_block<'a, T, R: Send + Sync + 'static>( + mut store: StoreContextMut<'a, T>, + future: impl Future) -> Result + 'static> + + Send + + Sync + + 'static, + _caller_instance: RuntimeComponentInstanceIndex, + ) -> Result<(R, StoreContextMut<'a, T>)> { + match pin!(future).poll(&mut Context::from_waker(&dummy_waker())) { + Poll::Ready(fun) => { + let result = fun(store.as_context_mut())?; + Ok((result, store)) + } + Poll::Pending => { + unreachable!() + } + } + } + + pub(crate) extern "C" fn task_backpressure( + _cx: *mut VMOpaqueContext, + _caller_instance: RuntimeComponentInstanceIndex, + _enabled: u32, + ) -> bool { + unreachable!() + } + + pub(crate) extern "C" fn task_return( + _cx: *mut VMOpaqueContext, + _ty: TypeTaskReturnIndex, + _storage: *mut MaybeUninit, + _storage_len: usize, + ) -> bool { + unreachable!() + } + + pub(crate) extern "C" fn task_wait( + _cx: *mut VMOpaqueContext, + _caller_instance: RuntimeComponentInstanceIndex, + _async_: bool, + _memory: *mut VMMemoryDefinition, + _payload: u32, + ) -> u64 { + unreachable!() + } + + pub(crate) extern "C" fn task_poll( + _cx: *mut VMOpaqueContext, + _caller_instance: RuntimeComponentInstanceIndex, + _async_: bool, + _memory: *mut VMMemoryDefinition, + _payload: u32, + ) -> u64 { + unreachable!() + } + + pub(crate) extern "C" fn task_yield(_cx: *mut VMOpaqueContext, _async_: bool) -> bool { + unreachable!() + } + + pub(crate) extern "C" fn subtask_drop( + _cx: *mut VMOpaqueContext, + _caller_instance: RuntimeComponentInstanceIndex, + _task_id: u32, + ) -> bool { + unreachable!() + } + + pub(crate) extern "C" fn async_enter( + _cx: *mut VMOpaqueContext, + _start: *mut VMFuncRef, + _return_: *mut VMFuncRef, + _caller_instance: RuntimeComponentInstanceIndex, + _task_return_type: TypeTaskReturnIndex, + _params: u32, + _results: u32, + ) -> bool { + unreachable!() + } + + pub(crate) extern "C" fn async_exit( + _cx: *mut VMOpaqueContext, + _callback: *mut VMFuncRef, + _post_return: *mut VMFuncRef, + _caller_instance: RuntimeComponentInstanceIndex, + _callee: *mut VMFuncRef, + _callee_instance: RuntimeComponentInstanceIndex, + _param_count: u32, + _result_count: u32, + _flags: u32, + ) -> u64 { + unreachable!() + } + + pub(crate) extern "C" fn future_new( + _vmctx: *mut VMOpaqueContext, + _ty: TypeFutureTableIndex, + ) -> u64 { + unreachable!() + } + + pub(crate) extern "C" fn future_write( + _vmctx: *mut VMOpaqueContext, + _memory: *mut VMMemoryDefinition, + _realloc: *mut VMFuncRef, + _string_encoding: u8, + _ty: TypeFutureTableIndex, + _future: u32, + _address: u32, + ) -> u64 { + unreachable!() + } + + pub(crate) extern "C" fn future_read( + _vmctx: *mut VMOpaqueContext, + _memory: *mut VMMemoryDefinition, + _realloc: *mut VMFuncRef, + _string_encoding: u8, + _ty: TypeFutureTableIndex, + _future: u32, + _address: u32, + ) -> u64 { + unreachable!() + } + + pub(crate) extern "C" fn future_cancel_write( + _vmctx: *mut VMOpaqueContext, + _ty: TypeFutureTableIndex, + _async_: bool, + _writer: u32, + ) -> u64 { + unreachable!() + } + + pub(crate) extern "C" fn future_cancel_read( + _vmctx: *mut VMOpaqueContext, + _ty: TypeFutureTableIndex, + _async_: bool, + _reader: u32, + ) -> u64 { + unreachable!() + } + + pub(crate) extern "C" fn future_close_writable( + _vmctx: *mut VMOpaqueContext, + _ty: TypeFutureTableIndex, + _writer: u32, + _error: u32, + ) -> bool { + unreachable!() + } + + pub(crate) extern "C" fn future_close_readable( + _vmctx: *mut VMOpaqueContext, + _ty: TypeFutureTableIndex, + _reader: u32, + ) -> bool { + unreachable!() + } + + pub(crate) extern "C" fn stream_new( + _vmctx: *mut VMOpaqueContext, + _ty: TypeStreamTableIndex, + ) -> u64 { + unreachable!() + } + + pub(crate) extern "C" fn stream_write( + _vmctx: *mut VMOpaqueContext, + _memory: *mut VMMemoryDefinition, + _realloc: *mut VMFuncRef, + _string_encoding: u8, + _ty: TypeStreamTableIndex, + _stream: u32, + _address: u32, + _count: u32, + ) -> u64 { + unreachable!() + } + + pub(crate) extern "C" fn stream_read( + _vmctx: *mut VMOpaqueContext, + _memory: *mut VMMemoryDefinition, + _realloc: *mut VMFuncRef, + _string_encoding: u8, + _ty: TypeStreamTableIndex, + _stream: u32, + _address: u32, + _count: u32, + ) -> u64 { + unreachable!() + } + + pub(crate) extern "C" fn stream_cancel_write( + _vmctx: *mut VMOpaqueContext, + _ty: TypeStreamTableIndex, + _async_: bool, + _writer: u32, + ) -> u64 { + unreachable!() + } + + pub(crate) extern "C" fn stream_cancel_read( + _vmctx: *mut VMOpaqueContext, + _ty: TypeStreamTableIndex, + _async_: bool, + _reader: u32, + ) -> u64 { + unreachable!() + } + + pub(crate) extern "C" fn stream_close_writable( + _vmctx: *mut VMOpaqueContext, + _ty: TypeStreamTableIndex, + _writer: u32, + _error: u32, + ) -> bool { + unreachable!() + } + + pub(crate) extern "C" fn stream_close_readable( + _vmctx: *mut VMOpaqueContext, + _ty: TypeStreamTableIndex, + _reader: u32, + ) -> bool { + unreachable!() + } + + pub(crate) extern "C" fn flat_stream_write( + _vmctx: *mut VMOpaqueContext, + _memory: *mut VMMemoryDefinition, + _realloc: *mut VMFuncRef, + _ty: TypeStreamTableIndex, + _payload_size: u32, + _payload_align: u32, + _stream: u32, + _address: u32, + _count: u32, + ) -> u64 { + unreachable!() + } + + pub(crate) extern "C" fn flat_stream_read( + _vmctx: *mut VMOpaqueContext, + _memory: *mut VMMemoryDefinition, + _realloc: *mut VMFuncRef, + _ty: TypeStreamTableIndex, + _payload_size: u32, + _payload_align: u32, + _stream: u32, + _address: u32, + _count: u32, + ) -> u64 { + unreachable!() + } + + pub(crate) extern "C" fn error_context_new( + _vmctx: *mut VMOpaqueContext, + _memory: *mut VMMemoryDefinition, + _realloc: *mut VMFuncRef, + _string_encoding: u8, + _ty: TypeComponentLocalErrorContextTableIndex, + _address: u32, + _count: u32, + ) -> u64 { + unreachable!() + } + + pub(crate) extern "C" fn error_context_debug_message( + _vmctx: *mut VMOpaqueContext, + _memory: *mut VMMemoryDefinition, + _realloc: *mut VMFuncRef, + _string_encoding: u8, + _ty: TypeComponentLocalErrorContextTableIndex, + _handle: u32, + _address: u32, + ) -> bool { + unreachable!() + } + + pub(crate) extern "C" fn error_context_drop( + _vmctx: *mut VMOpaqueContext, + _ty: TypeComponentLocalErrorContextTableIndex, + _error: u32, + ) -> bool { + unreachable!() + } + + pub struct ErrorContext; + + impl ErrorContext { + pub(crate) fn new(_rep: u32) -> Self { + unreachable!() + } + + pub(crate) fn into_val(self) -> Val { + unreachable!() + } + + pub(crate) fn lower( + &self, + _cx: &mut LowerContext<'_, T>, + _ty: InterfaceType, + _dst: &mut MaybeUninit<::Lower>, + ) -> Result<()> { + unreachable!() + } + + pub(crate) fn store( + &self, + _cx: &mut LowerContext<'_, T>, + _ty: InterfaceType, + _offset: usize, + ) -> Result<()> { + unreachable!() + } + + pub(crate) fn lift( + _cx: &mut LiftContext<'_>, + _ty: InterfaceType, + _src: &::Lower, + ) -> Result { + unreachable!() + } + + pub(crate) fn load( + _cx: &mut LiftContext<'_>, + _ty: InterfaceType, + _bytes: &[u8], + ) -> Result { + unreachable!() + } + } + + pub struct StreamReader

{ + _phantom: PhantomData

, + } + + impl

StreamReader

{ + pub(crate) fn new(_rep: u32) -> Self { + unreachable!() + } + + pub(crate) fn into_val(self) -> Val { + unreachable!() + } + + pub(crate) fn lower( + &self, + _cx: &mut LowerContext<'_, T>, + _ty: InterfaceType, + _dst: &mut MaybeUninit<::Lower>, + ) -> Result<()> { + unreachable!() + } + + pub(crate) fn store( + &self, + _cx: &mut LowerContext<'_, T>, + _ty: InterfaceType, + _offset: usize, + ) -> Result<()> { + unreachable!() + } + + pub(crate) fn lift( + _cx: &mut LiftContext<'_>, + _ty: InterfaceType, + _src: &::Lower, + ) -> Result { + unreachable!() + } + + pub(crate) fn load( + _cx: &mut LiftContext<'_>, + _ty: InterfaceType, + _bytes: &[u8], + ) -> Result { + unreachable!() + } + } + + pub struct FutureReader

{ + _phantom: PhantomData

, + } + + impl

FutureReader

{ + pub(crate) fn new(_rep: u32) -> Self { + unreachable!() + } + + pub(crate) fn into_val(self) -> Val { + unreachable!() + } + + pub(crate) fn lower( + &self, + _cx: &mut LowerContext<'_, T>, + _ty: InterfaceType, + _dst: &mut MaybeUninit<::Lower>, + ) -> Result<()> { + unreachable!() + } + + pub(crate) fn store( + &self, + _cx: &mut LowerContext<'_, T>, + _ty: InterfaceType, + _offset: usize, + ) -> Result<()> { + unreachable!() + } + + pub(crate) fn lift( + _cx: &mut LiftContext<'_>, + _ty: InterfaceType, + _src: &::Lower, + ) -> Result { + unreachable!() + } + + pub(crate) fn load( + _cx: &mut LiftContext<'_>, + _ty: InterfaceType, + _bytes: &[u8], + ) -> Result { + unreachable!() + } + } +} diff --git a/crates/wasmtime/src/runtime/component/storage.rs b/crates/wasmtime/src/runtime/component/storage.rs index 01537b42984d..25e43da8c688 100644 --- a/crates/wasmtime/src/runtime/component/storage.rs +++ b/crates/wasmtime/src/runtime/component/storage.rs @@ -37,7 +37,32 @@ pub unsafe fn slice_to_storage_mut(slice: &mut [MaybeUninit]) -> &mut // stay within the bounds of the number of actual values given rather than // reading past the end of an array. This shouldn't actually trip unless // there's a bug in Wasmtime though. - assert!(mem::size_of_val(slice) >= mem::size_of::()); + assert!( + mem::size_of_val(slice) >= mem::size_of::(), + "needed {}; got {}", + mem::size_of::(), + mem::size_of_val(slice) + ); &mut *slice.as_mut_ptr().cast() } + +/// Same as `storage_as_slice`, but in reverse +#[cfg(feature = "component-model-async")] +pub unsafe fn slice_to_storage(slice: &[ValRaw]) -> &T { + assert_raw_slice_compat::(); + + // This is an actual runtime assertion which if performance calls for we may + // need to relax to a debug assertion. This notably tries to ensure that we + // stay within the bounds of the number of actual values given rather than + // reading past the end of an array. This shouldn't actually trip unless + // there's a bug in Wasmtime though. + assert!( + mem::size_of_val(slice) >= mem::size_of::(), + "needed {}; got {}", + mem::size_of::(), + mem::size_of_val(slice) + ); + + &*slice.as_ptr().cast() +} diff --git a/crates/wasmtime/src/runtime/component/types.rs b/crates/wasmtime/src/runtime/component/types.rs index 548143f43dfe..028b624ea33b 100644 --- a/crates/wasmtime/src/runtime/component/types.rs +++ b/crates/wasmtime/src/runtime/component/types.rs @@ -7,9 +7,9 @@ use core::fmt; use core::ops::Deref; use wasmtime_environ::component::{ ComponentTypes, InterfaceType, ResourceIndex, TypeComponentIndex, TypeComponentInstanceIndex, - TypeDef, TypeEnumIndex, TypeFlagsIndex, TypeFuncIndex, TypeListIndex, TypeModuleIndex, - TypeOptionIndex, TypeRecordIndex, TypeResourceTableIndex, TypeResultIndex, TypeTupleIndex, - TypeVariantIndex, + TypeDef, TypeEnumIndex, TypeFlagsIndex, TypeFuncIndex, TypeFutureIndex, TypeFutureTableIndex, + TypeListIndex, TypeModuleIndex, TypeOptionIndex, TypeRecordIndex, TypeResourceTableIndex, + TypeResultIndex, TypeStreamIndex, TypeStreamTableIndex, TypeTupleIndex, TypeVariantIndex, }; use wasmtime_environ::PrimaryMap; @@ -145,9 +145,16 @@ impl TypeChecker<'_> { (InterfaceType::String, _) => false, (InterfaceType::Char, InterfaceType::Char) => true, (InterfaceType::Char, _) => false, - (InterfaceType::Future(_), _) - | (InterfaceType::Stream(_), _) - | (InterfaceType::ErrorContext(_), _) => todo!(), + (InterfaceType::Future(t1), InterfaceType::Future(t2)) => { + self.future_table_types_equal(t1, t2) + } + (InterfaceType::Future(_), _) => false, + (InterfaceType::Stream(t1), InterfaceType::Stream(t2)) => { + self.stream_table_types_equal(t1, t2) + } + (InterfaceType::Stream(_), _) => false, + (InterfaceType::ErrorContext(_), InterfaceType::ErrorContext(_)) => true, + (InterfaceType::ErrorContext(_), _) => false, } } @@ -247,6 +254,34 @@ impl TypeChecker<'_> { let b = &self.b_types[f2]; a.names == b.names } + + fn future_table_types_equal(&self, t1: TypeFutureTableIndex, t2: TypeFutureTableIndex) -> bool { + self.futures_equal(self.a_types[t1].ty, self.b_types[t2].ty) + } + + fn futures_equal(&self, t1: TypeFutureIndex, t2: TypeFutureIndex) -> bool { + let a = &self.a_types[t1]; + let b = &self.b_types[t2]; + match (a.payload, b.payload) { + (Some(t1), Some(t2)) => self.interface_types_equal(t1, t2), + (None, None) => true, + _ => false, + } + } + + fn stream_table_types_equal(&self, t1: TypeStreamTableIndex, t2: TypeStreamTableIndex) -> bool { + self.streams_equal(self.a_types[t1].ty, self.b_types[t2].ty) + } + + fn streams_equal(&self, t1: TypeStreamIndex, t2: TypeStreamIndex) -> bool { + let a = &self.a_types[t1]; + let b = &self.b_types[t2]; + match (a.payload, b.payload) { + (Some(t1), Some(t2)) => self.interface_types_equal(t1, t2), + (None, None) => true, + _ => false, + } + } } /// A `list` interface type @@ -419,7 +454,7 @@ impl PartialEq for OptionType { impl Eq for OptionType {} -/// An `expected` interface type +/// A `result` interface type #[derive(Clone, Debug)] pub struct ResultType(Handle); @@ -479,6 +514,58 @@ impl PartialEq for Flags { impl Eq for Flags {} +/// An `future` interface type +#[derive(Clone, Debug)] +pub struct FutureType(Handle); + +impl FutureType { + pub(crate) fn from(index: TypeFutureIndex, ty: &InstanceType<'_>) -> Self { + FutureType(Handle::new(index, ty)) + } + + /// Retrieve the type parameter for this `future`. + pub fn ty(&self) -> Option { + Some(Type::from( + self.0.types[self.0.index].payload.as_ref()?, + &self.0.instance(), + )) + } +} + +impl PartialEq for FutureType { + fn eq(&self, other: &Self) -> bool { + self.0.equivalent(&other.0, TypeChecker::futures_equal) + } +} + +impl Eq for FutureType {} + +/// An `stream` interface type +#[derive(Clone, Debug)] +pub struct StreamType(Handle); + +impl StreamType { + pub(crate) fn from(index: TypeStreamIndex, ty: &InstanceType<'_>) -> Self { + StreamType(Handle::new(index, ty)) + } + + /// Retrieve the type parameter for this `stream`. + pub fn ty(&self) -> Option { + Some(Type::from( + self.0.types[self.0.index].payload.as_ref()?, + &self.0.instance(), + )) + } +} + +impl PartialEq for StreamType { + fn eq(&self, other: &Self) -> bool { + self.0.equivalent(&other.0, TypeChecker::streams_equal) + } +} + +impl Eq for StreamType {} + /// Represents a component model interface type #[derive(Clone, PartialEq, Eq, Debug)] #[allow(missing_docs)] @@ -506,6 +593,9 @@ pub enum Type { Flags(Flags), Own(ResourceType), Borrow(ResourceType), + Future(FutureType), + Stream(StreamType), + ErrorContext, } impl Type { @@ -663,9 +753,9 @@ impl Type { InterfaceType::Flags(index) => Type::Flags(Flags::from(*index, instance)), InterfaceType::Own(index) => Type::Own(instance.resource_type(*index)), InterfaceType::Borrow(index) => Type::Borrow(instance.resource_type(*index)), - InterfaceType::Future(_) - | InterfaceType::Stream(_) - | InterfaceType::ErrorContext(_) => todo!(), + InterfaceType::Future(index) => Type::Future(instance.future_type(*index)), + InterfaceType::Stream(index) => Type::Stream(instance.stream_type(*index)), + InterfaceType::ErrorContext(_) => Type::ErrorContext, } } @@ -694,6 +784,9 @@ impl Type { Type::Flags(_) => "flags", Type::Own(_) => "own", Type::Borrow(_) => "borrow", + Type::Future(_) => "future", + Type::Stream(_) => "stream", + Type::ErrorContext => "error-context", } } } diff --git a/crates/wasmtime/src/runtime/component/values.rs b/crates/wasmtime/src/runtime/component/values.rs index cccbaf3ea609..8b78137d6de2 100644 --- a/crates/wasmtime/src/runtime/component/values.rs +++ b/crates/wasmtime/src/runtime/component/values.rs @@ -1,3 +1,4 @@ +use crate::component::concurrent::{ErrorContext, FutureReader, StreamReader}; use crate::component::func::{desc, Lift, LiftContext, Lower, LowerContext}; use crate::component::ResourceAny; use crate::prelude::*; @@ -86,6 +87,9 @@ pub enum Val { Result(Result>, Option>>), Flags(Vec), Resource(ResourceAny), + Future(FutureAny), + Stream(StreamAny), + ErrorContext(ErrorContextAny), } impl Val { @@ -198,9 +202,9 @@ impl Val { Val::Flags(flags.into()) } - InterfaceType::Future(_) - | InterfaceType::Stream(_) - | InterfaceType::ErrorContext(_) => todo!(), + InterfaceType::Future(_) => FutureReader::<()>::lift(cx, ty, next(src))?.into_val(), + InterfaceType::Stream(_) => StreamReader::<()>::lift(cx, ty, next(src))?.into_val(), + InterfaceType::ErrorContext(_) => ErrorContext::lift(cx, ty, next(src))?.into_val(), }) } @@ -322,9 +326,9 @@ impl Val { } Val::Flags(flags.into()) } - InterfaceType::Future(_) - | InterfaceType::Stream(_) - | InterfaceType::ErrorContext(_) => todo!(), + InterfaceType::Future(_) => FutureReader::<()>::load(cx, ty, bytes)?.into_val(), + InterfaceType::Stream(_) => StreamReader::<()>::load(cx, ty, bytes)?.into_val(), + InterfaceType::ErrorContext(_) => ErrorContext::load(cx, ty, bytes)?.into_val(), }) } @@ -435,9 +439,18 @@ impl Val { Ok(()) } (InterfaceType::Flags(_), _) => unexpected(ty, self), - (InterfaceType::Future(_), _) - | (InterfaceType::Stream(_), _) - | (InterfaceType::ErrorContext(_), _) => todo!(), + (InterfaceType::Future(_), Val::Future(FutureAny(rep))) => { + FutureReader::<()>::new(*rep).lower(cx, ty, next_mut(dst)) + } + (InterfaceType::Future(_), _) => unexpected(ty, self), + (InterfaceType::Stream(_), Val::Stream(StreamAny(rep))) => { + StreamReader::<()>::new(*rep).lower(cx, ty, next_mut(dst)) + } + (InterfaceType::Stream(_), _) => unexpected(ty, self), + (InterfaceType::ErrorContext(_), Val::ErrorContext(ErrorContextAny(rep))) => { + ErrorContext::new(*rep).lower(cx, ty, next_mut(dst)) + } + (InterfaceType::ErrorContext(_), _) => unexpected(ty, self), } } @@ -573,13 +586,22 @@ impl Val { Ok(()) } (InterfaceType::Flags(_), _) => unexpected(ty, self), - (InterfaceType::Future(_), _) - | (InterfaceType::Stream(_), _) - | (InterfaceType::ErrorContext(_), _) => todo!(), + (InterfaceType::Future(_), Val::Future(FutureAny(rep))) => { + FutureReader::<()>::new(*rep).store(cx, ty, offset) + } + (InterfaceType::Future(_), _) => unexpected(ty, self), + (InterfaceType::Stream(_), Val::Stream(StreamAny(rep))) => { + StreamReader::<()>::new(*rep).store(cx, ty, offset) + } + (InterfaceType::Stream(_), _) => unexpected(ty, self), + (InterfaceType::ErrorContext(_), Val::ErrorContext(ErrorContextAny(rep))) => { + ErrorContext::new(*rep).store(cx, ty, offset) + } + (InterfaceType::ErrorContext(_), _) => unexpected(ty, self), } } - fn desc(&self) -> &'static str { + pub(crate) fn desc(&self) -> &'static str { match self { Val::Bool(_) => "bool", Val::U8(_) => "u8", @@ -603,6 +625,9 @@ impl Val { Val::Result(_) => "result", Val::Resource(_) => "resource", Val::Flags(_) => "flags", + Val::Future(_) => "future", + Val::Stream(_) => "stream", + Val::ErrorContext(_) => "error-context", } } @@ -681,6 +706,12 @@ impl PartialEq for Val { (Self::Flags(_), _) => false, (Self::Resource(l), Self::Resource(r)) => l == r, (Self::Resource(_), _) => false, + (Self::Future(l), Self::Future(r)) => l == r, + (Self::Future(_), _) => false, + (Self::Stream(l), Self::Stream(r)) => l == r, + (Self::Stream(_), _) => false, + (Self::ErrorContext(l), Self::ErrorContext(r)) => l == r, + (Self::ErrorContext(_), _) => false, } } } @@ -1000,3 +1031,12 @@ fn unexpected(ty: InterfaceType, val: &Val) -> Result { val.desc() ) } + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct FutureAny(pub(crate) u32); + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct StreamAny(pub(crate) u32); + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ErrorContextAny(pub(crate) u32); diff --git a/crates/wasmtime/src/runtime/externals/table.rs b/crates/wasmtime/src/runtime/externals/table.rs index fd404859e7de..ed7b57e27840 100644 --- a/crates/wasmtime/src/runtime/externals/table.rs +++ b/crates/wasmtime/src/runtime/externals/table.rs @@ -95,14 +95,26 @@ impl Table { where T: Send, { - let mut store = store.as_context_mut(); + let store = store.as_context_mut(); assert!( store.0.async_support(), "cannot use `new_async` without enabling async support on the config" ); - store - .on_fiber(|store| Table::_new(store.0, ty, init)) + #[cfg(feature = "component-model-async")] + { + crate::component::concurrent::on_fiber(store, None, move |store| { + Table::_new(store.0, ty, init) + }) .await? + .0 + } + #[cfg(not(feature = "component-model-async"))] + { + let mut store = store; + store + .on_fiber(|store| Table::_new(store.0, ty, init)) + .await? + } } fn _new(store: &mut StoreOpaque, ty: TableType, init: Ref) -> Result { @@ -289,14 +301,26 @@ impl Table { where T: Send, { - let mut store = store.as_context_mut(); + let store = store.as_context_mut(); assert!( store.0.async_support(), "cannot use `grow_async` without enabling async support on the config" ); - store - .on_fiber(|store| self.grow(store, delta, init)) + #[cfg(feature = "component-model-async")] + { + crate::component::concurrent::on_fiber(store, None, move |store| { + self.grow(store, delta, init) + }) .await? + .0 + } + #[cfg(not(feature = "component-model-async"))] + { + let mut store = store; + store + .on_fiber(|store| self.grow(store, delta, init)) + .await? + } } /// Copy `len` elements from `src_table[src_index..]` into diff --git a/crates/wasmtime/src/runtime/func.rs b/crates/wasmtime/src/runtime/func.rs index fee813f446de..3c0ed39965e3 100644 --- a/crates/wasmtime/src/runtime/func.rs +++ b/crates/wasmtime/src/runtime/func.rs @@ -557,16 +557,29 @@ impl Func { ); assert!(ty.comes_from_same_engine(store.as_context().engine())); Func::new(store, ty, move |mut caller, params, results| { - let async_cx = caller - .store - .as_context_mut() - .0 - .async_cx() - .expect("Attempt to spawn new action on dying fiber"); - let mut future = Pin::from(func(caller, params, results)); - match unsafe { async_cx.block_on(future.as_mut()) } { - Ok(Ok(())) => Ok(()), - Ok(Err(trap)) | Err(trap) => Err(trap), + #[cfg(feature = "component-model-async")] + { + let async_cx = + crate::component::concurrent::AsyncCx::new(&mut caller.store.as_context_mut()); + let mut future = Pin::from(func(caller, params, results)); + match unsafe { async_cx.block_on::(future.as_mut(), None) } { + Ok((Ok(()), _)) => Ok(()), + Ok((Err(trap), _)) | Err(trap) => Err(trap), + } + } + #[cfg(not(feature = "component-model-async"))] + { + let async_cx = caller + .store + .as_context_mut() + .0 + .async_cx() + .expect("Attempt to spawn new action on dying fiber"); + let mut future = Pin::from(func(caller, params, results)); + match unsafe { async_cx.block_on(future.as_mut()) } { + Ok(Ok(())) => Ok(()), + Ok(Err(trap)) | Err(trap) => Err(trap), + } } }) } @@ -875,17 +888,31 @@ impl Func { concat!("cannot use `wrap_async` without enabling async support on the config") ); Func::wrap_inner(store, move |mut caller: Caller<'_, T>, args| { - let async_cx = caller - .store - .as_context_mut() - .0 - .async_cx() - .expect("Attempt to start async function on dying fiber"); - let mut future = Pin::from(func(caller, args)); - - match unsafe { async_cx.block_on(future.as_mut()) } { - Ok(ret) => ret.into_fallible(), - Err(e) => R::fallible_from_error(e), + #[cfg(feature = "component-model-async")] + { + let async_cx = + crate::component::concurrent::AsyncCx::new(&mut caller.store.as_context_mut()); + let mut future = Pin::from(func(caller, args)); + + match unsafe { async_cx.block_on::(future.as_mut(), None) } { + Ok((ret, _)) => ret.into_fallible(), + Err(e) => R::fallible_from_error(e), + } + } + #[cfg(not(feature = "component-model-async"))] + { + let async_cx = caller + .store + .as_context_mut() + .0 + .async_cx() + .expect("Attempt to start async function on dying fiber"); + let mut future = Pin::from(func(caller, args)); + + match unsafe { async_cx.block_on(future.as_mut()) } { + Ok(ret) => ret.into_fallible(), + Err(e) => R::fallible_from_error(e), + } } }) } @@ -1155,10 +1182,21 @@ impl Func { if need_gc { store.0.gc_async().await; } - let result = store - .on_fiber(|store| unsafe { self.call_impl_do_call(store, params, results) }) - .await??; - Ok(result) + #[cfg(feature = "component-model-async")] + { + crate::component::concurrent::on_fiber(store, None, move |store| unsafe { + self.call_impl_do_call(store, params, results) + }) + .await? + .0 + } + #[cfg(not(feature = "component-model-async"))] + { + let result = store + .on_fiber(|store| unsafe { self.call_impl_do_call(store, params, results) }) + .await??; + Ok(result) + } } /// Perform dynamic checks that the arguments given to us match @@ -2350,6 +2388,7 @@ impl HostContext { drop(store); let r = func(caller.sub_caller(), params); + if let Err(trap) = caller.store.0.call_hook(CallHook::ReturningFromHost) { break 'ret R::fallible_from_error(trap); } diff --git a/crates/wasmtime/src/runtime/func/typed.rs b/crates/wasmtime/src/runtime/func/typed.rs index 241b34c2eeb7..db3221b0c64c 100644 --- a/crates/wasmtime/src/runtime/func/typed.rs +++ b/crates/wasmtime/src/runtime/func/typed.rs @@ -132,8 +132,9 @@ where ) -> Result where T: Send, + Results: Send + Sync + 'static, { - let mut store = store.as_context_mut(); + let store = store.as_context_mut(); assert!( store.0.async_support(), "must use `call` with non-async stores" @@ -141,12 +142,25 @@ where if Self::need_gc_before_call_raw(store.0, ¶ms) { store.0.gc_async().await; } - store - .on_fiber(|store| { + #[cfg(feature = "component-model-async")] + { + crate::component::concurrent::on_fiber(store, None, |store| { let func = self.func.vm_func_ref(store.0); unsafe { Self::call_raw(store, &self.ty, func, params) } }) .await? + .0 + } + #[cfg(not(feature = "component-model-async"))] + { + let mut store = store; + store + .on_fiber(|store| { + let func = self.func.vm_func_ref(store.0); + unsafe { Self::call_raw(store, &self.ty, func, params) } + }) + .await? + } } #[inline] diff --git a/crates/wasmtime/src/runtime/instance.rs b/crates/wasmtime/src/runtime/instance.rs index 815833503f71..ade249deb7f5 100644 --- a/crates/wasmtime/src/runtime/instance.rs +++ b/crates/wasmtime/src/runtime/instance.rs @@ -227,9 +227,20 @@ impl Instance { "must use sync instantiation when async support is disabled", ); - store - .on_fiber(|store| Self::new_started_impl(store, module, imports)) + #[cfg(feature = "component-model-async")] + { + crate::component::concurrent::on_fiber(store.as_context_mut(), None, move |store| { + Self::new_started_impl(store, module, imports) + }) .await? + .0 + } + #[cfg(not(feature = "component-model-async"))] + { + store + .on_fiber(|store| Self::new_started_impl(store, module, imports)) + .await? + } } /// Internal function to create an instance which doesn't have its `start` diff --git a/crates/wasmtime/src/runtime/linker.rs b/crates/wasmtime/src/runtime/linker.rs index 95cbb548447b..1749ed6452c7 100644 --- a/crates/wasmtime/src/runtime/linker.rs +++ b/crates/wasmtime/src/runtime/linker.rs @@ -459,16 +459,29 @@ impl Linker { ); assert!(ty.comes_from_same_engine(self.engine())); self.func_new(module, name, ty, move |mut caller, params, results| { - let async_cx = caller - .store - .as_context_mut() - .0 - .async_cx() - .expect("Attempt to spawn new function on dying fiber"); - let mut future = Pin::from(func(caller, params, results)); - match unsafe { async_cx.block_on(future.as_mut()) } { - Ok(Ok(())) => Ok(()), - Ok(Err(trap)) | Err(trap) => Err(trap), + #[cfg(feature = "component-model-async")] + { + let async_cx = + crate::component::concurrent::AsyncCx::new(&mut caller.store.as_context_mut()); + let mut future = Pin::from(func(caller, params, results)); + match unsafe { async_cx.block_on::(future.as_mut(), None) } { + Ok((Ok(()), _)) => Ok(()), + Ok((Err(trap), _)) | Err(trap) => Err(trap), + } + } + #[cfg(not(feature = "component-model-async"))] + { + let async_cx = caller + .store + .as_context_mut() + .0 + .async_cx() + .expect("Attempt to spawn new function on dying fiber"); + let mut future = Pin::from(func(caller, params, results)); + match unsafe { async_cx.block_on(future.as_mut()) } { + Ok(Ok(())) => Ok(()), + Ok(Err(trap)) | Err(trap) => Err(trap), + } } }) } @@ -562,16 +575,31 @@ impl Linker { let func = HostFunc::wrap_inner( &self.engine, move |mut caller: Caller<'_, T>, args: Params| { - let async_cx = caller - .store - .as_context_mut() - .0 - .async_cx() - .expect("Attempt to start async function on dying fiber"); - let mut future = Pin::from(func(caller, args)); - match unsafe { async_cx.block_on(future.as_mut()) } { - Ok(ret) => ret.into_fallible(), - Err(e) => Args::fallible_from_error(e), + #[cfg(feature = "component-model-async")] + { + let async_cx = crate::component::concurrent::AsyncCx::new( + &mut caller.store.as_context_mut(), + ); + let mut future = Pin::from(func(caller, args)); + + match unsafe { async_cx.block_on::(future.as_mut(), None) } { + Ok((ret, _)) => ret.into_fallible(), + Err(e) => Args::fallible_from_error(e), + } + } + #[cfg(not(feature = "component-model-async"))] + { + let async_cx = caller + .store + .as_context_mut() + .0 + .async_cx() + .expect("Attempt to start async function on dying fiber"); + let mut future = Pin::from(func(caller, args)); + match unsafe { async_cx.block_on(future.as_mut()) } { + Ok(ret) => ret.into_fallible(), + Err(e) => Args::fallible_from_error(e), + } } }, ); diff --git a/crates/wasmtime/src/runtime/memory.rs b/crates/wasmtime/src/runtime/memory.rs index 498d04966cda..e266dc0d9492 100644 --- a/crates/wasmtime/src/runtime/memory.rs +++ b/crates/wasmtime/src/runtime/memory.rs @@ -261,12 +261,24 @@ impl Memory { where T: Send, { - let mut store = store.as_context_mut(); + let store = store.as_context_mut(); assert!( store.0.async_support(), "cannot use `new_async` without enabling async support on the config" ); - store.on_fiber(|store| Self::_new(store.0, ty)).await? + #[cfg(feature = "component-model-async")] + { + crate::component::concurrent::on_fiber(store, None, move |store| { + Self::_new(store.0, ty) + }) + .await? + .0 + } + #[cfg(not(feature = "component-model-async"))] + { + let mut store = store; + store.on_fiber(|store| Self::_new(store.0, ty)).await? + } } /// Helper function for attaching the memory to a "frankenstein" instance @@ -613,12 +625,24 @@ impl Memory { where T: Send, { - let mut store = store.as_context_mut(); + let store = store.as_context_mut(); assert!( store.0.async_support(), "cannot use `grow_async` without enabling async support on the config" ); - store.on_fiber(|store| self.grow(store, delta)).await? + #[cfg(feature = "component-model-async")] + { + crate::component::concurrent::on_fiber(store, None, move |store| { + self.grow(store, delta) + }) + .await? + .0 + } + #[cfg(not(feature = "component-model-async"))] + { + let mut store = store; + store.on_fiber(|store| self.grow(store, delta)).await? + } } fn wasmtime_memory(&self, store: &mut StoreOpaque) -> *mut crate::runtime::vm::Memory { diff --git a/crates/wasmtime/src/runtime/store.rs b/crates/wasmtime/src/runtime/store.rs index c881adf4ea56..692d8d8deb49 100644 --- a/crates/wasmtime/src/runtime/store.rs +++ b/crates/wasmtime/src/runtime/store.rs @@ -76,6 +76,8 @@ //! contents of `StoreOpaque`. This is an invariant that we, as the authors of //! `wasmtime`, must uphold for the public interface to be safe. +#[cfg(feature = "component-model-async")] +use crate::component::concurrent; use crate::hash_set::HashSet; use crate::instance::InstanceData; use crate::linker::Definition; @@ -226,6 +228,47 @@ pub struct StoreInner { Option) -> Result + Send + Sync>>, // for comments about `ManuallyDrop`, see `Store::into_data` data: ManuallyDrop, + #[cfg(feature = "component-model-async")] + concurrent_state: concurrent::ConcurrentState, +} + +impl StoreInner { + /// Yields execution to the caller on out-of-gas or epoch interruption. + /// + /// This only works on async futures and stores, and assumes that we're + /// executing on a fiber. This will yield execution back to the caller once. + #[cfg(feature = "async")] + fn async_yield_impl(&mut self) -> Result<()> { + use crate::runtime::vm::Yield; + + let mut future = Yield::new(); + + // When control returns, we have a `Result<()>` passed + // in from the host fiber. If this finished successfully then + // we were resumed normally via a `poll`, so keep going. If + // the future was dropped while we were yielded, then we need + // to clean up this fiber. Do so by raising a trap which will + // abort all wasm and get caught on the other side to clean + // things up. + #[cfg(feature = "component-model-async")] + unsafe { + let async_cx = + crate::component::concurrent::AsyncCx::new(&mut (&mut *self).as_context_mut()); + async_cx + .block_on( + Pin::new_unchecked(&mut future), + None::>, + )? + .0; + Ok(()) + } + #[cfg(not(feature = "component-model-async"))] + unsafe { + self.async_cx() + .expect("attempted to pull async context during shutdown") + .block_on(Pin::new_unchecked(&mut future)) + } + } } enum ResourceLimiterInner { @@ -426,7 +469,9 @@ struct AsyncState { #[derive(Clone, Copy)] struct PollContext { future_context: *mut Context<'static>, + #[cfg_attr(feature = "component-model-async", allow(dead_code))] guard_range_start: *mut u8, + #[cfg_attr(feature = "component-model-async", allow(dead_code))] guard_range_end: *mut u8, } @@ -616,6 +661,8 @@ impl Store { call_hook: None, epoch_deadline_behavior: None, data: ManuallyDrop::new(data), + #[cfg(feature = "component-model-async")] + concurrent_state: Default::default(), }); // Wasmtime uses the callee argument to host functions to learn about @@ -1132,6 +1179,35 @@ impl<'a, T> StoreContextMut<'a, T> { self.0.data_mut() } + #[cfg(feature = "component-model-async")] + pub(crate) fn concurrent_state(&mut self) -> &mut concurrent::ConcurrentState { + self.0.concurrent_state() + } + + pub(crate) fn async_guard_range(&mut self) -> Range<*mut u8> { + #[cfg(feature = "component-model-async")] + { + self.concurrent_state().async_guard_range() + } + #[cfg(not(feature = "component-model-async"))] + { + #[cfg(feature = "async")] + unsafe { + let ptr = self.0.inner.async_state.current_poll_cx.get(); + (*ptr).guard_range_start..(*ptr).guard_range_end + } + #[cfg(not(feature = "async"))] + { + core::ptr::null_mut()..core::ptr::null_mut() + } + } + } + + #[cfg(feature = "component-model-async")] + pub(crate) fn has_pkey(&self) -> bool { + self.0.pkey.is_some() + } + /// Returns the underlying [`Engine`] this store is connected to. pub fn engine(&self) -> &Engine { self.0.engine() @@ -1217,6 +1293,11 @@ impl StoreInner { &mut self.data } + #[cfg(feature = "component-model-async")] + fn concurrent_state(&mut self) -> &mut concurrent::ConcurrentState { + &mut self.concurrent_state + } + #[inline] pub fn call_hook(&mut self, s: CallHook) -> Result<()> { if self.inner.pkey.is_none() && self.call_hook.is_none() { @@ -1256,14 +1337,33 @@ impl StoreInner { #[cfg(all(feature = "async", feature = "call-hook"))] CallHookInner::Async(handler) => unsafe { - self.inner - .async_cx() - .ok_or_else(|| anyhow!("couldn't grab async_cx for call hook"))? - .block_on( - handler - .handle_call_event((&mut *self).as_context_mut(), s) - .as_mut(), - )? + #[cfg(feature = "component-model-async")] + { + let async_cx = crate::component::concurrent::AsyncCx::try_new( + &mut (&mut *self).as_context_mut(), + ) + .ok_or_else(|| anyhow!("couldn't grab async_cx for call hook"))?; + + async_cx + .block_on( + handler + .handle_call_event((&mut *self).as_context_mut(), s) + .as_mut(), + None::>, + )? + .0 + } + #[cfg(not(feature = "component-model-async"))] + { + self.inner + .async_cx() + .ok_or_else(|| anyhow!("couldn't grab async_cx for call hook"))? + .block_on( + handler + .handle_call_event((&mut *self).as_context_mut(), s) + .as_mut(), + )? + } }, CallHookInner::ForceTypeParameterToBeUsed { uninhabited, .. } => { @@ -1916,30 +2016,6 @@ impl StoreOpaque { self.set_fuel(self.get_fuel()?) } - /// Yields execution to the caller on out-of-gas or epoch interruption. - /// - /// This only works on async futures and stores, and assumes that we're - /// executing on a fiber. This will yield execution back to the caller once. - #[cfg(feature = "async")] - fn async_yield_impl(&mut self) -> Result<()> { - use crate::runtime::vm::Yield; - - let mut future = Yield::new(); - - // When control returns, we have a `Result<()>` passed - // in from the host fiber. If this finished successfully then - // we were resumed normally via a `poll`, so keep going. If - // the future was dropped while we were yielded, then we need - // to clean up this fiber. Do so by raising a trap which will - // abort all wasm and get caught on the other side to clean - // things up. - unsafe { - self.async_cx() - .expect("attempted to pull async context during shutdown") - .block_on(Pin::new_unchecked(&mut future)) - } - } - #[inline] pub fn signal_handler(&self) -> Option<*const SignalHandler> { let handler = self.signal_handler.as_ref()?; @@ -2134,18 +2210,6 @@ at https://bytecodealliance.org/security. self.num_component_instances += 1; } - pub(crate) fn async_guard_range(&self) -> Range<*mut u8> { - #[cfg(feature = "async")] - unsafe { - let ptr = self.async_state.current_poll_cx.get(); - (*ptr).guard_range_start..(*ptr).guard_range_end - } - #[cfg(not(feature = "async"))] - { - core::ptr::null_mut()..core::ptr::null_mut() - } - } - #[cfg(feature = "async")] fn allocate_fiber_stack(&mut self) -> Result { if let Some(stack) = self.async_state.last_fiber_stack.take() { @@ -2603,14 +2667,35 @@ unsafe impl crate::runtime::vm::VMStore for StoreInner { } #[cfg(feature = "async")] Some(ResourceLimiterInner::Async(ref mut limiter)) => unsafe { - self.inner - .async_cx() - .expect("ResourceLimiterAsync requires async Store") - .block_on( - limiter(&mut self.data) - .memory_growing(current, desired, maximum) - .as_mut(), - )? + #[cfg(feature = "component-model-async")] + { + _ = limiter; + let async_cx = crate::component::concurrent::AsyncCx::new( + &mut (&mut *self).as_context_mut(), + ); + let Some(ResourceLimiterInner::Async(ref mut limiter)) = self.limiter else { + unreachable!(); + }; + async_cx + .block_on::( + limiter(&mut self.data) + .memory_growing(current, desired, maximum) + .as_mut(), + None, + )? + .0 + } + #[cfg(not(feature = "component-model-async"))] + { + self.inner + .async_cx() + .expect("ResourceLimiterAsync requires async Store") + .block_on( + limiter(&mut self.data) + .memory_growing(current, desired, maximum) + .as_mut(), + )? + } }, None => Ok(true), } @@ -2641,7 +2726,7 @@ unsafe impl crate::runtime::vm::VMStore for StoreInner { // Need to borrow async_cx before the mut borrow of the limiter. // self.async_cx() panicks when used with a non-async store, so // wrap this in an option. - #[cfg(feature = "async")] + #[cfg(all(feature = "async", not(feature = "component-model-async")))] let async_cx = if self.async_support() && matches!(self.limiter, Some(ResourceLimiterInner::Async(_))) { @@ -2656,13 +2741,34 @@ unsafe impl crate::runtime::vm::VMStore for StoreInner { } #[cfg(feature = "async")] Some(ResourceLimiterInner::Async(ref mut limiter)) => unsafe { - async_cx - .expect("ResourceLimiterAsync requires async Store") - .block_on( - limiter(&mut self.data) - .table_growing(current, desired, maximum) - .as_mut(), - )? + #[cfg(feature = "component-model-async")] + { + _ = limiter; + let async_cx = crate::component::concurrent::AsyncCx::new( + &mut (&mut *self).as_context_mut(), + ); + let Some(ResourceLimiterInner::Async(ref mut limiter)) = self.limiter else { + unreachable!(); + }; + async_cx + .block_on::( + limiter(&mut self.data) + .table_growing(current, desired, maximum) + .as_mut(), + None, + )? + .0 + } + #[cfg(not(feature = "component-model-async"))] + { + async_cx + .expect("ResourceLimiterAsync requires async Store") + .block_on( + limiter(&mut self.data) + .table_growing(current, desired, maximum) + .as_mut(), + )? + } }, None => Ok(true), } diff --git a/crates/wasmtime/src/runtime/vm/component.rs b/crates/wasmtime/src/runtime/vm/component.rs index 06e941b041f7..c59b3a23c7f5 100644 --- a/crates/wasmtime/src/runtime/vm/component.rs +++ b/crates/wasmtime/src/runtime/vm/component.rs @@ -12,6 +12,7 @@ use crate::runtime::vm::{ VMOpaqueContext, VMStore, VMStoreRawPtr, VMWasmCallFunction, ValRaw, VmPtr, VmSafe, }; use alloc::alloc::Layout; +use alloc::collections::BTreeMap; use alloc::sync::Arc; use core::any::Any; use core::marker; @@ -27,10 +28,17 @@ use wasmtime_environ::{HostPtr, PrimaryMap, VMSharedTypeIndex}; // 32-bit platforms const INVALID_PTR: usize = 0xdead_dead_beef_beef_u64 as usize; +mod error_contexts; mod libcalls; mod resources; +mod states; +pub use self::error_contexts::{GlobalErrorContextRefCount, LocalErrorContextRefCount}; pub use self::resources::{CallContexts, ResourceTable, ResourceTables}; +pub use self::states::StateTable; + +#[cfg(feature = "component-model-async")] +pub use self::resources::CallContext; /// Runtime representation of a component instance and all state necessary for /// the instance itself. @@ -58,6 +66,35 @@ pub struct ComponentInstance { /// is how this field is manipulated. component_resource_tables: PrimaryMap, + component_waitable_tables: PrimaryMap>, + + /// (Sub)Component specific error context tracking + /// + /// At the component level, only the number of references (`usize`) to a given error context is tracked, + /// with state related to the error context being held at the component model level, in concurrent + /// state. + /// + /// The state tables in the (sub)component local tracking must contain a pointer into the global + /// error context lookups in order to ensure that in contexts where only the local reference is present + /// the global state can still be maintained/updated. + component_error_context_tables: + PrimaryMap>, + + /// Reference counts for all component error contexts + /// + /// NOTE: it is possible the global ref count to be *greater* than the sum of + /// (sub)component ref counts as tracked by `component_error_context_tables`, for + /// example when the host holds one or more references to error contexts. + /// + /// The key of this primary map is often referred to as the "rep" (i.e. host-side + /// component-wide representation) of the index into concurrent state for a given + /// stored `ErrorContext`. + /// + /// Stated another way, `TypeComponentGlobalErrorContextTableIndex` is essentially the same + /// as a `TableId`. + component_global_error_context_ref_counts: + BTreeMap, + /// Storage for the type information about resources within this component /// instance. /// @@ -86,6 +123,7 @@ pub struct ComponentInstance { /// which this function pointer was registered. /// * `ty` - the type index, relative to the tables in `vmctx`, that is the /// type of the function being called. +/// * `caller_instance` - The (sub)component instance of the caller. /// * `flags` - the component flags for may_enter/leave corresponding to the /// component instance that the lowering happened within. /// * `opt_memory` - this nullable pointer represents the memory configuration @@ -106,7 +144,7 @@ pub struct ComponentInstance { /// or not. On failure this function records trap information in TLS which /// should be suitable for reading later. // -// FIXME: 9 arguments is probably too many. The `data` through `string-encoding` +// FIXME: 11 arguments is probably too many. The `data` through `string-encoding` // parameters should probably get packaged up into the `VMComponentContext`. // Needs benchmarking one way or another though to figure out what the best // balance is here. @@ -114,6 +152,7 @@ pub type VMLoweringCallee = extern "C" fn( vmctx: NonNull, data: NonNull, ty: u32, + caller_instance: u32, flags: NonNull, opt_memory: *mut VMMemoryDefinition, opt_realloc: *mut VMFuncRef, @@ -156,6 +195,37 @@ pub struct VMComponentContext { _marker: marker::PhantomPinned, } +/// Represents the state of a stream or future handle. +#[derive(Debug, Eq, PartialEq)] +pub enum StreamFutureState { + /// Both the read and write ends are owned by the same component instance. + Local, + /// Only the write end is owned by this component instance. + Write, + /// Only the read end is owned by this component instance. + Read, + /// A read or write is in progress. + Busy, +} + +/// Represents the state of a waitable handle. +#[derive(Debug)] +pub enum WaitableState { + /// Represents a task handle. + Task, + /// Represents a stream handle. + Stream(TypeStreamTableIndex, StreamFutureState), + /// Represents a future handle. + Future(TypeFutureTableIndex, StreamFutureState), +} + +/// Represents the state associated with an error context +#[derive(Debug, PartialEq, Eq, PartialOrd)] +pub struct ErrorContextState { + /// Debug message associated with the error context + pub(crate) debug_msg: String, +} + impl ComponentInstance { /// Converts the `vmctx` provided into a `ComponentInstance` and runs the /// provided closure with that instance. @@ -205,12 +275,30 @@ impl ComponentInstance { ) { assert!(alloc_size >= Self::alloc_layout(&offsets).size()); - let num_tables = runtime_info.component().num_resource_tables; - let mut component_resource_tables = PrimaryMap::with_capacity(num_tables); - for _ in 0..num_tables { + let num_resource_tables = runtime_info.component().num_resource_tables; + let mut component_resource_tables = PrimaryMap::with_capacity(num_resource_tables); + for _ in 0..num_resource_tables { component_resource_tables.push(ResourceTable::default()); } + let num_waitable_tables = runtime_info.component().num_runtime_component_instances; + let mut component_waitable_tables = + PrimaryMap::with_capacity(usize::try_from(num_waitable_tables).unwrap()); + for _ in 0..num_waitable_tables { + component_waitable_tables.push(StateTable::default()); + } + + let num_error_context_tables = runtime_info.component().num_error_context_tables; + let mut component_error_context_tables = PrimaryMap::< + TypeComponentLocalErrorContextTableIndex, + StateTable, + >::with_capacity(num_error_context_tables); + for _ in 0..num_error_context_tables { + component_error_context_tables.push(StateTable::default()); + } + + let component_global_error_context_ref_counts = BTreeMap::new(); + ptr::write( ptr.as_ptr(), ComponentInstance { @@ -224,6 +312,9 @@ impl ComponentInstance { .unwrap(), ), component_resource_tables, + component_waitable_tables, + component_error_context_tables, + component_global_error_context_ref_counts, runtime_info, resource_types, store: VMStoreRawPtr(store), @@ -298,6 +389,18 @@ impl ComponentInstance { } } + /// Returns the async callback pointer corresponding to the index provided. + /// + /// This can only be called after `idx` has been initialized at runtime + /// during the instantiation process of a component. + pub fn runtime_callback(&self, idx: RuntimeCallbackIndex) -> NonNull { + unsafe { + let ret = *self.vmctx_plus_offset::>(self.offsets.runtime_callback(idx)); + debug_assert!(ret.as_ptr() as usize != INVALID_PTR); + ret.as_non_null() + } + } + /// Returns the post-return pointer corresponding to the index provided. /// /// This can only be called after `idx` has been initialized at runtime @@ -479,7 +582,7 @@ impl ComponentInstance { } // In debug mode set non-null bad values to all "pointer looking" bits - // and pices related to lowering and such. This'll help detect any + // and pieces related to lowering and such. This'll help detect any // erroneous usage and enable debug assertions above as well to prevent // loading these before they're configured or setting them twice. if cfg!(debug_assertions) { @@ -604,6 +707,33 @@ impl ComponentInstance { &mut self.component_resource_tables } + /// Retrieves the tables for tracking waitable handles and their states with respect + /// to the components which own them. + pub fn component_waitable_tables( + &mut self, + ) -> &mut PrimaryMap> { + &mut self.component_waitable_tables + } + + /// Retrieves the tables for tracking error-context handles and their reference + /// counts with respect to the components which own them. + pub fn component_error_context_tables( + &mut self, + ) -> &mut PrimaryMap< + TypeComponentLocalErrorContextTableIndex, + StateTable, + > { + &mut self.component_error_context_tables + } + + /// Retrieves the tables for tracking component-global error-context handles + /// and their reference counts with respect to the components which own them. + pub fn component_global_error_context_ref_counts( + &mut self, + ) -> &mut BTreeMap { + &mut self.component_global_error_context_ref_counts + } + /// Returns the destructor and instance flags for the specified resource /// table type. /// @@ -664,6 +794,136 @@ impl ComponentInstance { pub(crate) fn resource_exit_call(&mut self) -> Result<()> { self.resource_tables().exit_call() } + + pub(crate) fn future_transfer( + &mut self, + src_idx: u32, + src: TypeFutureTableIndex, + dst: TypeFutureTableIndex, + ) -> Result { + let src_instance = self.component_types()[src].instance; + let dst_instance = self.component_types()[dst].instance; + let [src_table, dst_table] = self + .component_waitable_tables + .get_many_mut([src_instance, dst_instance]) + .unwrap(); + let (rep, WaitableState::Future(src_ty, src_state)) = + src_table.get_mut_by_index(src_idx)? + else { + bail!("invalid future handle"); + }; + if *src_ty != src { + bail!("invalid future handle"); + } + match src_state { + StreamFutureState::Local => { + *src_state = StreamFutureState::Write; + assert!(dst_table.get_mut_by_rep(rep).is_none()); + dst_table.insert(rep, WaitableState::Future(dst, StreamFutureState::Read)) + } + StreamFutureState::Read => { + src_table.remove_by_index(src_idx)?; + if let Some((dst_idx, dst_state)) = dst_table.get_mut_by_rep(rep) { + let WaitableState::Future(dst_ty, dst_state) = dst_state else { + unreachable!(); + }; + assert_eq!(*dst_ty, dst); + assert_eq!(*dst_state, StreamFutureState::Write); + *dst_state = StreamFutureState::Local; + Ok(dst_idx) + } else { + dst_table.insert(rep, WaitableState::Future(dst, StreamFutureState::Read)) + } + } + StreamFutureState::Write => bail!("cannot transfer write end of future"), + StreamFutureState::Busy => bail!("cannot transfer busy future"), + } + } + + pub(crate) fn stream_transfer( + &mut self, + src_idx: u32, + src: TypeStreamTableIndex, + dst: TypeStreamTableIndex, + ) -> Result { + let src_instance = self.component_types()[src].instance; + let dst_instance = self.component_types()[dst].instance; + let [src_table, dst_table] = self + .component_waitable_tables + .get_many_mut([src_instance, dst_instance]) + .unwrap(); + let (rep, WaitableState::Stream(src_ty, src_state)) = + src_table.get_mut_by_index(src_idx)? + else { + bail!("invalid stream handle"); + }; + if *src_ty != src { + bail!("invalid stream handle"); + } + match src_state { + StreamFutureState::Local => { + *src_state = StreamFutureState::Write; + assert!(dst_table.get_mut_by_rep(rep).is_none()); + dst_table.insert(rep, WaitableState::Stream(dst, StreamFutureState::Read)) + } + StreamFutureState::Read => { + src_table.remove_by_index(src_idx)?; + if let Some((dst_idx, dst_state)) = dst_table.get_mut_by_rep(rep) { + let WaitableState::Stream(dst_ty, dst_state) = dst_state else { + unreachable!(); + }; + assert_eq!(*dst_ty, dst); + assert_eq!(*dst_state, StreamFutureState::Write); + *dst_state = StreamFutureState::Local; + Ok(dst_idx) + } else { + dst_table.insert(rep, WaitableState::Stream(dst, StreamFutureState::Read)) + } + } + StreamFutureState::Write => bail!("cannot transfer write end of stream"), + StreamFutureState::Busy => bail!("cannot transfer busy stream"), + } + } + + /// Transfer the state of a given error context from one component to another + pub(crate) fn error_context_transfer( + &mut self, + src_idx: u32, + src: TypeComponentLocalErrorContextTableIndex, + dst: TypeComponentLocalErrorContextTableIndex, + ) -> Result { + let (rep, _) = { + let rep = self + .component_error_context_tables + .get_mut(src) + .context("error context table index present in (sub)component lookup")? + .get_mut_by_index(src_idx)?; + rep + }; + let dst = self + .component_error_context_tables + .get_mut(dst) + .context("error context table index present in (sub)component lookup")?; + + // Update the component local for the destination + let updated_count = if let Some((dst_idx, count)) = dst.get_mut_by_rep(rep) { + (*count).0 += 1; + dst_idx + } else { + dst.insert(rep, LocalErrorContextRefCount(1))? + }; + + // Update the global (cross-subcomponent) count for error contexts + // as the new component has essentially created a new reference that will + // be dropped/handled independently + let global_ref_count = self + .component_global_error_context_ref_counts + .get_mut(&TypeComponentGlobalErrorContextTableIndex::from_u32(rep)) + .context("global ref count present for existing (sub)component error context")?; + global_ref_count.0 += 1; + + Ok(updated_count) + } } impl VMComponentContext { @@ -684,7 +944,7 @@ impl VMComponentContext { /// This type can be dereferenced to `ComponentInstance` to access the /// underlying methods. pub struct OwnedComponentInstance { - ptr: SendSyncPtr, + pub(crate) ptr: SendSyncPtr, } impl OwnedComponentInstance { diff --git a/crates/wasmtime/src/runtime/vm/component/error_contexts.rs b/crates/wasmtime/src/runtime/vm/component/error_contexts.rs new file mode 100644 index 000000000000..435197f79e5e --- /dev/null +++ b/crates/wasmtime/src/runtime/vm/component/error_contexts.rs @@ -0,0 +1,16 @@ +/// Error context reference count local to a given (sub)component +/// +/// This reference count is localized to a single (sub)component, +/// rather than the global cross-component count (i.e. that determines +/// when a error context can be completely removed) +#[repr(transparent)] +pub struct LocalErrorContextRefCount(pub(crate) usize); + +/// Error context reference count across a [`ComponentInstance`] +/// +/// Contrasted to `LocalErrorContextRefCount`, this count is maintained +/// across all sub-components in a given component. +/// +/// When this count is zero it is *definitely* safe to remove an error context. +#[repr(transparent)] +pub struct GlobalErrorContextRefCount(pub(crate) usize); diff --git a/crates/wasmtime/src/runtime/vm/component/libcalls.rs b/crates/wasmtime/src/runtime/vm/component/libcalls.rs index e2720b775cf2..f066c8f2b6ae 100644 --- a/crates/wasmtime/src/runtime/vm/component/libcalls.rs +++ b/crates/wasmtime/src/runtime/vm/component/libcalls.rs @@ -7,7 +7,10 @@ use core::cell::Cell; use core::convert::Infallible; use core::ptr::NonNull; use core::slice; -use wasmtime_environ::component::TypeResourceTableIndex; +use wasmtime_environ::component::{ + TypeComponentLocalErrorContextTableIndex, TypeFutureTableIndex, TypeResourceTableIndex, + TypeStreamTableIndex, +}; const UTF16_TAG: usize = 1 << 31; @@ -606,6 +609,7 @@ unsafe fn task_return( ) -> Result<()> { ComponentInstance::from_vmctx(vmctx, |instance| { (*instance.store()).component_async_store().task_return( + instance, wasmtime_environ::component::TypeTupleIndex::from_u32(ty), storage.cast::(), storage_len, @@ -623,6 +627,7 @@ unsafe fn task_wait( ) -> Result { ComponentInstance::from_vmctx(vmctx, |instance| { (*instance.store()).component_async_store().task_wait( + instance, wasmtime_environ::component::RuntimeComponentInstanceIndex::from_u32(caller_instance), async_ != 0, memory.cast::(), @@ -641,6 +646,7 @@ unsafe fn task_poll( ) -> Result { ComponentInstance::from_vmctx(vmctx, |instance| { (*instance.store()).component_async_store().task_poll( + instance, wasmtime_environ::component::RuntimeComponentInstanceIndex::from_u32(caller_instance), async_ != 0, memory.cast::(), @@ -654,7 +660,7 @@ unsafe fn task_yield(vmctx: NonNull, async_: u8) -> Result<( ComponentInstance::from_vmctx(vmctx, |instance| { (*instance.store()) .component_async_store() - .task_yield(async_ != 0) + .task_yield(instance, async_ != 0) }) } @@ -666,6 +672,7 @@ unsafe fn subtask_drop( ) -> Result<()> { ComponentInstance::from_vmctx(vmctx, |instance| { (*instance.store()).component_async_store().subtask_drop( + instance, wasmtime_environ::component::RuntimeComponentInstanceIndex::from_u32(caller_instance), task_id, ) @@ -708,6 +715,7 @@ unsafe fn async_exit( ) -> Result { ComponentInstance::from_vmctx(vmctx, |instance| { (*instance.store()).component_async_store().async_exit( + instance, callback.cast::(), post_return.cast::(), wasmtime_environ::component::RuntimeComponentInstanceIndex::from_u32(caller_instance), @@ -720,32 +728,433 @@ unsafe fn async_exit( }) } +#[cfg(feature = "component-model-async")] unsafe fn future_transfer( vmctx: NonNull, src_idx: u32, src_table: u32, dst_table: u32, ) -> Result { - _ = (vmctx, src_idx, src_table, dst_table); - todo!() + let src_table = TypeFutureTableIndex::from_u32(src_table); + let dst_table = TypeFutureTableIndex::from_u32(dst_table); + ComponentInstance::from_vmctx(vmctx, |instance| { + instance.future_transfer(src_idx, src_table, dst_table) + }) } +#[cfg(feature = "component-model-async")] unsafe fn stream_transfer( vmctx: NonNull, src_idx: u32, src_table: u32, dst_table: u32, ) -> Result { - _ = (vmctx, src_idx, src_table, dst_table); - todo!() + let src_table = TypeStreamTableIndex::from_u32(src_table); + let dst_table = TypeStreamTableIndex::from_u32(dst_table); + ComponentInstance::from_vmctx(vmctx, |instance| { + instance.stream_transfer(src_idx, src_table, dst_table) + }) } +#[cfg(feature = "component-model-async")] unsafe fn error_context_transfer( vmctx: NonNull, src_idx: u32, src_table: u32, dst_table: u32, ) -> Result { - _ = (vmctx, src_idx, src_table, dst_table); - todo!() + let src_table = TypeComponentLocalErrorContextTableIndex::from_u32(src_table); + let dst_table = TypeComponentLocalErrorContextTableIndex::from_u32(dst_table); + ComponentInstance::from_vmctx(vmctx, |instance| { + instance.error_context_transfer(src_idx, src_table, dst_table) + }) +} + +#[cfg(feature = "component-model-async")] +unsafe fn future_new(vmctx: NonNull, ty: u32) -> Result { + ComponentInstance::from_vmctx(vmctx, |instance| { + (*instance.store()).component_async_store().future_new( + instance, + wasmtime_environ::component::TypeFutureTableIndex::from_u32(ty), + ) + }) +} + +#[cfg(feature = "component-model-async")] +unsafe fn future_write( + vmctx: NonNull, + memory: *mut u8, + realloc: *mut u8, + string_encoding: u8, + ty: u32, + future: u32, + address: u32, +) -> Result { + ComponentInstance::from_vmctx(vmctx, |instance| { + (*instance.store()).component_async_store().future_write( + instance, + memory.cast::(), + realloc.cast::(), + string_encoding, + wasmtime_environ::component::TypeFutureTableIndex::from_u32(ty), + future, + address, + ) + }) +} + +#[cfg(feature = "component-model-async")] +unsafe fn future_read( + vmctx: NonNull, + memory: *mut u8, + realloc: *mut u8, + string_encoding: u8, + ty: u32, + future: u32, + address: u32, +) -> Result { + ComponentInstance::from_vmctx(vmctx, |instance| { + (*instance.store()).component_async_store().future_read( + instance, + memory.cast::(), + realloc.cast::(), + string_encoding, + wasmtime_environ::component::TypeFutureTableIndex::from_u32(ty), + future, + address, + ) + }) +} + +#[cfg(feature = "component-model-async")] +unsafe fn future_cancel_write( + vmctx: NonNull, + ty: u32, + async_: u8, + writer: u32, +) -> Result { + ComponentInstance::from_vmctx(vmctx, |instance| { + (*instance.store()) + .component_async_store() + .future_cancel_write( + instance, + wasmtime_environ::component::TypeFutureTableIndex::from_u32(ty), + async_ != 0, + writer, + ) + }) +} + +#[cfg(feature = "component-model-async")] +unsafe fn future_cancel_read( + vmctx: NonNull, + ty: u32, + async_: u8, + reader: u32, +) -> Result { + ComponentInstance::from_vmctx(vmctx, |instance| { + (*instance.store()) + .component_async_store() + .future_cancel_read( + instance, + wasmtime_environ::component::TypeFutureTableIndex::from_u32(ty), + async_ != 0, + reader, + ) + }) +} + +#[cfg(feature = "component-model-async")] +unsafe fn future_close_writable( + vmctx: NonNull, + ty: u32, + writer: u32, + error: u32, +) -> Result<()> { + ComponentInstance::from_vmctx(vmctx, |instance| { + (*instance.store()) + .component_async_store() + .future_close_writable( + instance, + wasmtime_environ::component::TypeFutureTableIndex::from_u32(ty), + writer, + error, + ) + }) +} + +#[cfg(feature = "component-model-async")] +unsafe fn future_close_readable( + vmctx: NonNull, + ty: u32, + reader: u32, +) -> Result<()> { + ComponentInstance::from_vmctx(vmctx, |instance| { + (*instance.store()) + .component_async_store() + .future_close_readable( + instance, + wasmtime_environ::component::TypeFutureTableIndex::from_u32(ty), + reader, + ) + }) +} + +#[cfg(feature = "component-model-async")] +unsafe fn stream_new(vmctx: NonNull, ty: u32) -> Result { + ComponentInstance::from_vmctx(vmctx, |instance| { + (*instance.store()).component_async_store().stream_new( + instance, + wasmtime_environ::component::TypeStreamTableIndex::from_u32(ty), + ) + }) +} + +#[cfg(feature = "component-model-async")] +unsafe fn stream_write( + vmctx: NonNull, + memory: *mut u8, + realloc: *mut u8, + string_encoding: u8, + ty: u32, + stream: u32, + address: u32, + count: u32, +) -> Result { + ComponentInstance::from_vmctx(vmctx, |instance| { + (*instance.store()).component_async_store().stream_write( + instance, + memory.cast::(), + realloc.cast::(), + string_encoding, + wasmtime_environ::component::TypeStreamTableIndex::from_u32(ty), + stream, + address, + count, + ) + }) +} + +#[cfg(feature = "component-model-async")] +unsafe fn stream_read( + vmctx: NonNull, + memory: *mut u8, + realloc: *mut u8, + string_encoding: u8, + ty: u32, + stream: u32, + address: u32, + count: u32, +) -> Result { + ComponentInstance::from_vmctx(vmctx, |instance| { + (*instance.store()).component_async_store().stream_read( + instance, + memory.cast::(), + realloc.cast::(), + string_encoding, + wasmtime_environ::component::TypeStreamTableIndex::from_u32(ty), + stream, + address, + count, + ) + }) +} + +#[cfg(feature = "component-model-async")] +unsafe fn stream_cancel_write( + vmctx: NonNull, + ty: u32, + async_: u8, + writer: u32, +) -> Result { + ComponentInstance::from_vmctx(vmctx, |instance| { + (*instance.store()) + .component_async_store() + .stream_cancel_write( + instance, + wasmtime_environ::component::TypeStreamTableIndex::from_u32(ty), + async_ != 0, + writer, + ) + }) +} + +#[cfg(feature = "component-model-async")] +unsafe fn stream_cancel_read( + vmctx: NonNull, + ty: u32, + async_: u8, + reader: u32, +) -> Result { + ComponentInstance::from_vmctx(vmctx, |instance| { + (*instance.store()) + .component_async_store() + .stream_cancel_read( + instance, + wasmtime_environ::component::TypeStreamTableIndex::from_u32(ty), + async_ != 0, + reader, + ) + }) +} + +#[cfg(feature = "component-model-async")] +unsafe fn stream_close_writable( + vmctx: NonNull, + ty: u32, + writer: u32, + error: u32, +) -> Result<()> { + ComponentInstance::from_vmctx(vmctx, |instance| { + (*instance.store()) + .component_async_store() + .stream_close_writable( + instance, + wasmtime_environ::component::TypeStreamTableIndex::from_u32(ty), + writer, + error, + ) + }) +} + +#[cfg(feature = "component-model-async")] +unsafe fn stream_close_readable( + vmctx: NonNull, + ty: u32, + reader: u32, +) -> Result<()> { + ComponentInstance::from_vmctx(vmctx, |instance| { + (*instance.store()) + .component_async_store() + .stream_close_readable( + instance, + wasmtime_environ::component::TypeStreamTableIndex::from_u32(ty), + reader, + ) + }) +} + +#[cfg(feature = "component-model-async")] +unsafe fn flat_stream_write( + vmctx: NonNull, + memory: *mut u8, + realloc: *mut u8, + ty: u32, + payload_size: u32, + payload_align: u32, + stream: u32, + address: u32, + count: u32, +) -> Result { + ComponentInstance::from_vmctx(vmctx, |instance| { + (*instance.store()) + .component_async_store() + .flat_stream_write( + instance, + memory.cast::(), + realloc.cast::(), + wasmtime_environ::component::TypeStreamTableIndex::from_u32(ty), + payload_size, + payload_align, + stream, + address, + count, + ) + }) +} + +#[cfg(feature = "component-model-async")] +unsafe fn flat_stream_read( + vmctx: NonNull, + memory: *mut u8, + realloc: *mut u8, + ty: u32, + payload_size: u32, + payload_align: u32, + stream: u32, + address: u32, + count: u32, +) -> Result { + ComponentInstance::from_vmctx(vmctx, |instance| { + (*instance.store()) + .component_async_store() + .flat_stream_read( + instance, + memory.cast::(), + realloc.cast::(), + wasmtime_environ::component::TypeStreamTableIndex::from_u32(ty), + payload_size, + payload_align, + stream, + address, + count, + ) + }) +} + +#[cfg(feature = "component-model-async")] +unsafe fn error_context_new( + vmctx: NonNull, + memory: *mut u8, + realloc: *mut u8, + string_encoding: u8, + ty: u32, + debug_msg_address: u32, + debug_msg_len: u32, +) -> Result { + ComponentInstance::from_vmctx(vmctx, |instance| { + (*instance.store()) + .component_async_store() + .error_context_new( + instance, + memory.cast::(), + realloc.cast::(), + string_encoding, + wasmtime_environ::component::TypeComponentLocalErrorContextTableIndex::from_u32(ty), + debug_msg_address, + debug_msg_len, + ) + }) +} + +#[cfg(feature = "component-model-async")] +unsafe fn error_context_debug_message( + vmctx: NonNull, + memory: *mut u8, + realloc: *mut u8, + string_encoding: u8, + ty: u32, + err_ctx_handle: u32, + debug_msg_address: u32, +) -> Result<()> { + ComponentInstance::from_vmctx(vmctx, |instance| { + (*instance.store()) + .component_async_store() + .error_context_debug_message( + instance, + memory.cast::(), + realloc.cast::(), + string_encoding, + wasmtime_environ::component::TypeComponentLocalErrorContextTableIndex::from_u32(ty), + err_ctx_handle, + debug_msg_address, + ) + }) +} + +#[cfg(feature = "component-model-async")] +unsafe fn error_context_drop( + vmctx: NonNull, + ty: u32, + err_ctx_handle: u32, +) -> Result<()> { + ComponentInstance::from_vmctx(vmctx, |instance| { + (*instance.store()) + .component_async_store() + .error_context_drop( + instance, + wasmtime_environ::component::TypeComponentLocalErrorContextTableIndex::from_u32(ty), + err_ctx_handle, + ) + }) } diff --git a/crates/wasmtime/src/runtime/vm/component/resources.rs b/crates/wasmtime/src/runtime/vm/component/resources.rs index 8ea8b7f15f87..de3cb8e3c653 100644 --- a/crates/wasmtime/src/runtime/vm/component/resources.rs +++ b/crates/wasmtime/src/runtime/vm/component/resources.rs @@ -107,8 +107,19 @@ pub struct CallContexts { scopes: Vec, } +impl CallContexts { + pub fn push(&mut self, cx: CallContext) { + self.scopes.push(cx); + } + + pub fn pop(&mut self) -> Option { + self.scopes.pop() + } +} + +/// State related to borrows for a specific call. #[derive(Default)] -struct CallContext { +pub struct CallContext { lenders: Vec, borrow_count: u32, } diff --git a/crates/wasmtime/src/runtime/vm/component/states.rs b/crates/wasmtime/src/runtime/vm/component/states.rs new file mode 100644 index 000000000000..2f09de1eea93 --- /dev/null +++ b/crates/wasmtime/src/runtime/vm/component/states.rs @@ -0,0 +1,134 @@ +use { + alloc::vec::Vec, + anyhow::{bail, Result}, + core::mem, +}; + +/// The maximum handle value is specified in +/// +/// currently and keeps the upper bit free for use in the component. +const MAX_HANDLE: u32 = 1 << 30; + +enum Slot { + Free { next: u32 }, + Occupied { rep: u32, state: T }, +} + +pub struct StateTable { + next: u32, + slots: Vec>, + // TODO: This is a sparse table (where zero means "no entry"); it might make + // more sense to use a `HashMap` here, but we'd need one that's + // no_std-compatible. A `BTreeMap` might also be appropriate if we restrict + // ourselves to `alloc::collections`. + reps_to_indexes: Vec, +} + +impl Default for StateTable { + fn default() -> Self { + Self { + next: 0, + slots: Vec::new(), + reps_to_indexes: Vec::new(), + } + } +} + +impl StateTable { + pub fn insert(&mut self, rep: u32, state: T) -> Result { + if matches!(self + .reps_to_indexes + .get(usize::try_from(rep).unwrap()), Some(idx) if *idx != 0) + { + bail!("rep {rep} already exists in this table"); + } + + let next = self.next as usize; + if next == self.slots.len() { + self.slots.push(Slot::Free { + next: self.next.checked_add(1).unwrap(), + }); + } + let ret = self.next; + self.next = match mem::replace(&mut self.slots[next], Slot::Occupied { rep, state }) { + Slot::Free { next } => next, + _ => unreachable!(), + }; + // The component model reserves index 0 as never allocatable so add one + // to the table index to start the numbering at 1 instead. Also note + // that the component model places an upper-limit per-table on the + // maximum allowed index. + let ret = ret + 1; + if ret >= MAX_HANDLE { + bail!("cannot allocate another handle: index overflow"); + } + + let rep = usize::try_from(rep).unwrap(); + if self.reps_to_indexes.len() <= rep { + self.reps_to_indexes.resize(rep.checked_add(1).unwrap(), 0); + } + + self.reps_to_indexes[rep] = ret; + + Ok(ret) + } + + fn handle_index_to_table_index(&self, idx: u32) -> Option { + // NB: `idx` is decremented by one to account for the `+1` above during + // allocation. + let idx = idx.checked_sub(1)?; + usize::try_from(idx).ok() + } + + fn get_mut(&mut self, idx: u32) -> Result<&mut Slot> { + let slot = self + .handle_index_to_table_index(idx) + .and_then(|i| self.slots.get_mut(i)); + match slot { + None | Some(Slot::Free { .. }) => bail!("unknown handle index {idx}"), + Some(slot) => Ok(slot), + } + } + + pub fn has_handle(&self, idx: u32) -> bool { + matches!( + self.handle_index_to_table_index(idx) + .and_then(|i| self.slots.get(i)), + Some(Slot::Occupied { .. }) + ) + } + + pub fn get_mut_by_index(&mut self, idx: u32) -> Result<(u32, &mut T)> { + let slot = self + .handle_index_to_table_index(idx) + .and_then(|i| self.slots.get_mut(i)); + match slot { + None | Some(Slot::Free { .. }) => bail!("unknown handle index {idx}"), + Some(Slot::Occupied { rep, state }) => Ok((*rep, state)), + } + } + + pub fn get_mut_by_rep(&mut self, rep: u32) -> Option<(u32, &mut T)> { + let index = *self.reps_to_indexes.get(usize::try_from(rep).unwrap())?; + if index > 0 { + let (_, state) = self.get_mut_by_index(index).unwrap(); + Some((index, state)) + } else { + None + } + } + + pub fn remove_by_index(&mut self, idx: u32) -> Result<(u32, T)> { + let to_fill = Slot::Free { next: self.next }; + let Slot::Occupied { rep, state } = mem::replace(self.get_mut(idx)?, to_fill) else { + unreachable!() + }; + self.next = idx - 1; + { + let rep = usize::try_from(rep).unwrap(); + assert_eq!(idx, self.reps_to_indexes[rep]); + self.reps_to_indexes[rep] = 0; + } + Ok((rep, state)) + } +} diff --git a/crates/wasmtime/src/runtime/vm/interpreter.rs b/crates/wasmtime/src/runtime/vm/interpreter.rs index f57e2551cc15..886371b5105e 100644 --- a/crates/wasmtime/src/runtime/vm/interpreter.rs +++ b/crates/wasmtime/src/runtime/vm/interpreter.rs @@ -378,7 +378,7 @@ impl InterpreterRef<'_> { use wasmtime_environ::component::ComponentBuiltinFunctionIndex; if id == const { HostCall::ComponentLowerImport.index() } { - call!(@host VMLoweringCallee(nonnull, nonnull, u32, nonnull, ptr, ptr, u8, u8, nonnull, size) -> bool); + call!(@host VMLoweringCallee(nonnull, nonnull, u32, u32, nonnull, ptr, ptr, u8, u8, nonnull, size) -> bool); } macro_rules! component { diff --git a/crates/wasmtime/src/runtime/vm/traphandlers.rs b/crates/wasmtime/src/runtime/vm/traphandlers.rs index 91a8ca6b27f3..846cd52c1612 100644 --- a/crates/wasmtime/src/runtime/vm/traphandlers.rs +++ b/crates/wasmtime/src/runtime/vm/traphandlers.rs @@ -360,39 +360,45 @@ where F: FnMut(NonNull, Option>) -> bool, { let caller = store.0.default_caller(); - let result = CallThreadState::new(store.0, caller).with(|cx| match store.0.executor() { - // In interpreted mode directly invoke the host closure since we won't - // be using host-based `setjmp`/`longjmp` as that's not going to save - // the context we want. - ExecutorRef::Interpreter(r) => { - cx.jmp_buf - .set(CallThreadState::JMP_BUF_INTERPRETER_SENTINEL); - closure(caller, Some(r)) - } + let async_guard_range = store.async_guard_range(); + let result = CallThreadState::new(store.0, async_guard_range, caller).with(|cx| { + match store.0.executor() { + // In interpreted mode directly invoke the host closure since we won't + // be using host-based `setjmp`/`longjmp` as that's not going to save + // the context we want. + ExecutorRef::Interpreter(r) => { + cx.jmp_buf + .set(CallThreadState::JMP_BUF_INTERPRETER_SENTINEL); + closure(caller, Some(r)) + } - // In native mode, however, defer to C to do the `setjmp` since Rust - // doesn't understand `setjmp`. - // - // Note that here we pass a function pointer to C to catch longjmp - // within, here it's `call_closure`, and that passes `None` for the - // interpreter since this branch is only ever taken if the interpreter - // isn't present. - #[cfg(has_host_compiler_backend)] - ExecutorRef::Native => traphandlers::wasmtime_setjmp( - cx.jmp_buf.as_ptr(), - { - extern "C" fn call_closure(payload: *mut u8, caller: NonNull) -> bool - where - F: FnMut(NonNull, Option>) -> bool, + // In native mode, however, defer to C to do the `setjmp` since Rust + // doesn't understand `setjmp`. + // + // Note that here we pass a function pointer to C to catch longjmp + // within, here it's `call_closure`, and that passes `None` for the + // interpreter since this branch is only ever taken if the interpreter + // isn't present. + #[cfg(has_host_compiler_backend)] + ExecutorRef::Native => traphandlers::wasmtime_setjmp( + cx.jmp_buf.as_ptr(), { - unsafe { (*(payload as *mut F))(caller, None) } - } - - call_closure:: - }, - &mut closure as *mut F as *mut u8, - caller, - ), + extern "C" fn call_closure( + payload: *mut u8, + caller: NonNull, + ) -> bool + where + F: FnMut(NonNull, Option>) -> bool, + { + unsafe { (*(payload as *mut F))(caller, None) } + } + + call_closure:: + }, + &mut closure as *mut F as *mut u8, + caller, + ), + } }); return match result { @@ -463,7 +469,11 @@ mod call_thread_state { pub const JMP_BUF_INTERPRETER_SENTINEL: *mut u8 = 1 as *mut u8; #[inline] - pub(super) fn new(store: &mut StoreOpaque, caller: NonNull) -> CallThreadState { + pub(super) fn new( + store: &mut StoreOpaque, + async_guard_range: Range<*mut u8>, + caller: NonNull, + ) -> CallThreadState { let limits = unsafe { Instance::from_vmctx(caller, |i| i.runtime_limits()) .read() @@ -473,7 +483,7 @@ mod call_thread_state { // Don't try to plumb #[cfg] everywhere for this field, just pretend // we're using it on miri/windows to silence compiler warnings. - let _: Range<_> = store.async_guard_range(); + let _: Range<_> = async_guard_range; CallThreadState { unwind: Cell::new(None), @@ -486,7 +496,7 @@ mod call_thread_state { capture_coredump: store.engine().config().coredump_on_trap, limits, #[cfg(all(has_native_signals, unix))] - async_guard_range: store.async_guard_range(), + async_guard_range, prev: Cell::new(ptr::null()), old_last_wasm_exit_fp: Cell::new(unsafe { *limits.as_ref().last_wasm_exit_fp.get() diff --git a/crates/wasmtime/src/runtime/wave/component.rs b/crates/wasmtime/src/runtime/wave/component.rs index 238512012f1c..39c615bfccdf 100644 --- a/crates/wasmtime/src/runtime/wave/component.rs +++ b/crates/wasmtime/src/runtime/wave/component.rs @@ -41,7 +41,11 @@ impl WasmType for component::Type { Self::Result(_) => WasmTypeKind::Result, Self::Flags(_) => WasmTypeKind::Flags, - Self::Own(_) | Self::Borrow(_) => WasmTypeKind::Unsupported, + Self::Own(_) + | Self::Borrow(_) + | Self::Stream(_) + | Self::Future(_) + | Self::ErrorContext => WasmTypeKind::Unsupported, } } @@ -134,7 +138,9 @@ impl WasmValue for component::Val { Self::Option(_) => WasmTypeKind::Option, Self::Result(_) => WasmTypeKind::Result, Self::Flags(_) => WasmTypeKind::Flags, - Self::Resource(_) => WasmTypeKind::Unsupported, + Self::Resource(_) | Self::Stream(_) | Self::Future(_) | Self::ErrorContext(_) => { + WasmTypeKind::Unsupported + } } } diff --git a/crates/wast/src/component.rs b/crates/wast/src/component.rs index 8a7f19dc08d0..1346a7f11361 100644 --- a/crates/wast/src/component.rs +++ b/crates/wast/src/component.rs @@ -284,6 +284,9 @@ fn mismatch(expected: &WastVal<'_>, actual: &Val) -> Result<()> { Val::Result(..) => "result", Val::Flags(..) => "flags", Val::Resource(..) => "resource", + Val::Future(..) => "future", + Val::Stream(..) => "stream", + Val::ErrorContext(..) => "error-context", }; bail!("expected `{expected}` got `{actual}`") } diff --git a/tests/all/component_model/bindgen.rs b/tests/all/component_model/bindgen.rs index e89b04f0ca09..344105a8f0f5 100644 --- a/tests/all/component_model/bindgen.rs +++ b/tests/all/component_model/bindgen.rs @@ -5,7 +5,7 @@ use super::engine; use anyhow::Result; use wasmtime::{ component::{Component, Linker}, - Store, + Config, Engine, Store, }; mod ownership; @@ -58,6 +58,72 @@ mod no_imports { } } +mod no_imports_concurrent { + use super::*; + use wasmtime::component::PromisesUnordered; + + wasmtime::component::bindgen!({ + inline: " + package foo:foo; + + world no-imports { + export foo: interface { + foo: func(); + } + + export bar: func(); + } + ", + async: true, + concurrent_exports: true, + }); + + #[tokio::test] + async fn run() -> Result<()> { + let mut config = Config::new(); + config.wasm_component_model_async(true); + config.async_support(true); + let engine = &Engine::new(&config)?; + + let component = Component::new( + &engine, + r#" + (component + (core module $m + (import "" "task.return" (func $task-return)) + (func (export "bar") (result i32) + call $task-return + i32.const 0 + ) + (func (export "callback") (param i32 i32 i32 i32) (result i32) unreachable) + ) + (core func $task-return (canon task.return)) + (core instance $i (instantiate $m + (with "" (instance (export "task.return" (func $task-return)))) + )) + + (func $f (export "bar") + (canon lift (core func $i "bar") async (callback (func $i "callback"))) + ) + + (instance $i (export "foo" (func $f))) + (export "foo" (instance $i)) + ) + "#, + )?; + + let linker = Linker::new(&engine); + let mut store = Store::new(&engine, ()); + let no_imports = NoImports::instantiate_async(&mut store, &component, &linker).await?; + let mut promises = PromisesUnordered::new(); + promises.push(no_imports.call_bar(&mut store).await?); + promises.push(no_imports.foo().call_foo(&mut store).await?); + assert!(promises.next(&mut store).await?.is_some()); + assert!(promises.next(&mut store).await?.is_some()); + Ok(()) + } +} + mod one_import { use super::*; @@ -121,6 +187,110 @@ mod one_import { } } +mod one_import_concurrent { + use { + super::*, + std::future::Future, + wasmtime::{component, StoreContextMut}, + }; + + wasmtime::component::bindgen!({ + inline: " + package foo:foo; + + world no-imports { + import foo: interface { + foo: func(); + } + + export bar: func(); + } + ", + async: true, + concurrent_imports: true, + concurrent_exports: true, + }); + + #[tokio::test] + async fn run() -> Result<()> { + let mut config = Config::new(); + config.wasm_component_model_async(true); + config.async_support(true); + let engine = &Engine::new(&config)?; + + let component = Component::new( + &engine, + r#" + (component + (import "foo" (instance $foo-instance + (export "foo" (func)) + )) + (core module $libc + (memory (export "memory") 1) + ) + (core instance $libc-instance (instantiate $libc)) + (core module $m + (import "" "foo" (func $foo (param i32 i32) (result i32))) + (import "" "task.return" (func $task-return)) + (func (export "bar") (result i32) + i32.const 0 + i32.const 0 + call $foo + drop + call $task-return + i32.const 0 + ) + (func (export "callback") (param i32 i32 i32 i32) (result i32) unreachable) + ) + (core func $foo (canon lower (func $foo-instance "foo") async (memory $libc-instance "memory"))) + (core func $task-return (canon task.return)) + (core instance $i (instantiate $m + (with "" (instance + (export "task.return" (func $task-return)) + (export "foo" (func $foo)) + )) + )) + + (func $f (export "bar") + (canon lift (core func $i "bar") async (callback (func $i "callback"))) + ) + + (instance $i (export "foo" (func $f))) + (export "foo" (instance $i)) + ) + "#, + )?; + + #[derive(Default)] + struct MyImports { + hit: bool, + } + + impl foo::Host for MyImports { + type Data = MyImports; + + fn foo( + mut store: StoreContextMut<'_, Self::Data>, + ) -> impl Future) + 'static> + + Send + + Sync + + 'static { + store.data_mut().hit = true; + async { component::for_any(|_| ()) } + } + } + + let mut linker = Linker::new(&engine); + foo::add_to_linker(&mut linker, |f: &mut MyImports| f)?; + let mut store = Store::new(&engine, MyImports::default()); + let no_imports = NoImports::instantiate_async(&mut store, &component, &linker).await?; + let promise = no_imports.call_bar(&mut store).await?; + promise.get(&mut store).await?; + assert!(store.data().hit); + Ok(()) + } +} + mod resources_at_world_level { use super::*; use wasmtime::component::Resource; diff --git a/tests/all/component_model/call_hook.rs b/tests/all/component_model/call_hook.rs index 91f71151aa48..5064a6b7f3d6 100644 --- a/tests/all/component_model/call_hook.rs +++ b/tests/all/component_model/call_hook.rs @@ -610,12 +610,15 @@ async fn drop_suspended_async_hook() -> Result<()> { times: u32, } - impl Future for PollNTimes { + impl Future for PollNTimes + where + F::Output: std::fmt::Debug, + { type Output = (); fn poll(mut self: Pin<&mut Self>, task: &mut task::Context<'_>) -> Poll<()> { for i in 0..self.times { match Pin::new(&mut self.future).poll(task) { - Poll::Ready(_) => panic!("future should not be ready at {i}"), + Poll::Ready(v) => panic!("future should not be ready at {i}; result is {v:?}"), Poll::Pending => {} } } diff --git a/tests/all/component_model/dynamic.rs b/tests/all/component_model/dynamic.rs index a27fd52df6e2..a6417b07d3a2 100644 --- a/tests/all/component_model/dynamic.rs +++ b/tests/all/component_model/dynamic.rs @@ -87,7 +87,7 @@ fn primitives() -> Result<()> { .call_and_post_return(&mut store, &output, &mut []) .unwrap_err(); assert!( - err.to_string().contains("expected 1 results(s), got 0"), + err.to_string().contains("expected 1 result(s), got 0"), "{err}" ); diff --git a/tests/all/component_model/func.rs b/tests/all/component_model/func.rs index 2632a830b348..e60f554ce3bc 100644 --- a/tests/all/component_model/func.rs +++ b/tests/all/component_model/func.rs @@ -821,13 +821,219 @@ fn strings() -> Result<()> { Ok(()) } -#[test] -fn many_parameters() -> Result<()> { - let component = format!( +#[tokio::test] +async fn async_reentrance() -> Result<()> { + let component = r#" + (component + (core module $shim + (import "" "task.return" (func $task-return (param i32))) + (table (export "funcs") 1 1 funcref) + (func (export "export") (param i32) (result i32) + (call_indirect (i32.const 0) (local.get 0)) + ) + (func (export "callback") (param i32 i32 i32 i32) (result i32) unreachable) + ) + (core func $task-return (canon task.return (result u32))) + (core instance $shim (instantiate $shim + (with "" (instance (export "task.return" (func $task-return)))) + )) + (func $shim-export (param "p1" u32) (result u32) + (canon lift (core func $shim "export") async (callback (func $shim "callback"))) + ) + + (component $inner + (import "import" (func $import (param "p1" u32) (result u32))) + (core module $libc (memory (export "memory") 1)) + (core instance $libc (instantiate $libc)) + (core func $import (canon lower (func $import) async (memory $libc "memory"))) + + (core module $m + (import "libc" "memory" (memory 1)) + (import "" "import" (func $import (param i32 i32) (result i32))) + (import "" "task.return" (func $task-return (param i32))) + (func (export "export") (param i32) (result i32) + (i32.store offset=0 (i32.const 1200) (local.get 0)) + (call $import (i32.const 1200) (i32.const 1204)) + drop + (call $task-return (i32.load offset=0 (i32.const 1204))) + i32.const 0 + ) + (func (export "callback") (param i32 i32 i32 i32) (result i32) unreachable) + ) + (core type $task-return-type (func (param i32))) + (core func $task-return (canon task.return (result u32))) + (core instance $i (instantiate $m + (with "" (instance + (export "task.return" (func $task-return)) + (export "import" (func $import)) + )) + (with "libc" (instance $libc)) + )) + (func (export "export") (param "p1" u32) (result u32) + (canon lift (core func $i "export") async (callback (func $i "callback"))) + ) + ) + (instance $inner (instantiate $inner (with "import" (func $shim-export)))) + + (core module $libc (memory (export "memory") 1)) + (core instance $libc (instantiate $libc)) + (core func $inner-export (canon lower (func $inner "export") async (memory $libc "memory"))) + + (core module $donut + (import "" "funcs" (table 1 1 funcref)) + (import "libc" "memory" (memory 1)) + (import "" "import" (func $import (param i32 i32) (result i32))) + (import "" "task.return" (func $task-return (param i32))) + (func $host-export (export "export") (param i32) (result i32) + (i32.store offset=0 (i32.const 1200) (local.get 0)) + (call $import (i32.const 1200) (i32.const 1204)) + drop + (call $task-return (i32.load offset=0 (i32.const 1204))) + i32.const 0 + ) + (func $guest-export (export "guest-export") (param i32) (result i32) unreachable) + (func (export "callback") (param i32 i32 i32 i32) (result i32) unreachable) + (func $start + (table.set (i32.const 0) (ref.func $guest-export)) + ) + (start $start) + ) + + (core instance $donut (instantiate $donut + (with "" (instance + (export "task.return" (func $task-return)) + (export "import" (func $inner-export)) + (export "funcs" (table $shim "funcs")) + )) + (with "libc" (instance $libc)) + )) + (func (export "export") (param "p1" u32) (result u32) + (canon lift (core func $donut "export") async (callback (func $donut "callback"))) + ) + )"#; + + let mut config = Config::new(); + config.wasm_component_model_async(true); + config.async_support(true); + let engine = &Engine::new(&config)?; + let component = Component::new(&engine, component)?; + let mut store = Store::new(&engine, ()); + + let instance = Linker::new(&engine) + .instantiate_async(&mut store, &component) + .await?; + + let func = instance.get_typed_func::<(u32,), (u32,)>(&mut store, "export")?; + + match func.call_concurrent(&mut store, (42,)).await { + Ok(_) => panic!(), + Err(e) => assert!(format!("{e:?}").contains("cannot enter component instance")), + } + + Ok(()) +} + +#[tokio::test] +async fn missing_task_return_call_stackless() -> Result<()> { + test_missing_task_return_call(r#"(component + (core module $m + (import "" "task.return" (func $task-return)) + (func (export "foo") (result i32) + i32.const 0 + ) + (func (export "callback") (param i32 i32 i32 i32) (result i32) unreachable) + ) + (core func $task-return (canon task.return)) + (core instance $i (instantiate $m + (with "" (instance (export "task.return" (func $task-return)))) + )) + (func (export "foo") (canon lift (core func $i "foo") async (callback (func $i "callback")))) + )"#).await +} + +#[tokio::test] +async fn missing_task_return_call_stackful() -> Result<()> { + test_missing_task_return_call( r#"(component (core module $m - (memory (export "memory") 1) - (func (export "foo") (param i32) (result i32) + (import "" "task.return" (func $task-return)) + (func (export "foo")) + ) + (core func $task-return (canon task.return)) + (core instance $i (instantiate $m + (with "" (instance (export "task.return" (func $task-return)))) + )) + (func (export "foo") (canon lift (core func $i "foo") async)) + )"#, + ) + .await +} + +async fn test_missing_task_return_call(component: &str) -> Result<()> { + let mut config = Config::new(); + config.wasm_component_model_async(true); + config.async_support(true); + let engine = &Engine::new(&config)?; + let component = Component::new(&engine, component)?; + let mut store = Store::new(&engine, ()); + + let instance = Linker::new(&engine) + .instantiate_async(&mut store, &component) + .await?; + + let func = instance.get_typed_func::<(), ()>(&mut store, "foo")?; + + match func.call_concurrent(&mut store, ()).await { + Ok(_) => panic!(), + Err(e) => { + assert!(format!("{e:?}") + .contains("wasm trap: async-lifted export failed to produce a result")) + } + } + + Ok(()) +} + +#[tokio::test] +async fn many_parameters() -> Result<()> { + test_many_parameters(false, false).await +} + +#[tokio::test] +async fn many_parameters_concurrent() -> Result<()> { + test_many_parameters(false, true).await +} + +#[tokio::test] +async fn many_parameters_dynamic() -> Result<()> { + test_many_parameters(true, false).await +} + +#[tokio::test] +async fn many_parameters_dynamic_concurrent() -> Result<()> { + test_many_parameters(true, true).await +} + +async fn test_many_parameters(dynamic: bool, concurrent: bool) -> Result<()> { + let (body, async_opts) = if concurrent { + ( + r#" + (call $task-return + (i32.const 0) + (i32.mul + (memory.size) + (i32.const 65536) + ) + (local.get 0) + ) + + (i32.const 0) + "#, + r#"async (callback (func $i "callback"))"#, + ) + } else { + ( + r#" (local $base i32) ;; Allocate space for the return @@ -855,11 +1061,28 @@ fn many_parameters() -> Result<()> { (local.get 0)) (local.get $base) + "#, + "", + ) + }; + + let component = format!( + r#"(component + (core module $m + (import "" "task.return" (func $task-return (param i32 i32 i32))) + (memory (export "memory") 1) + (func (export "foo") (param i32) (result i32) + {body} ) + (func (export "callback") (param i32 i32 i32 i32) (result i32) unreachable) {REALLOC_AND_FREE} ) - (core instance $i (instantiate $m)) + (type $tuple (tuple (list u8) u32)) + (core func $task-return (canon task.return (result $tuple))) + (core instance $i (instantiate $m + (with "" (instance (export "task.return" (func $task-return)))) + )) (type $t (func (param "p1" s8) ;; offset 0, size 1 @@ -870,43 +1093,35 @@ fn many_parameters() -> Result<()> { (param "p6" string) ;; offset 24, size 8 (param "p7" (list u32)) ;; offset 32, size 8 (param "p8" bool) ;; offset 40, size 1 - (param "p0" bool) ;; offset 40, size 1 - (param "pa" char) ;; offset 44, size 4 - (param "pb" (list bool)) ;; offset 48, size 8 - (param "pc" (list char)) ;; offset 56, size 8 - (param "pd" (list string)) ;; offset 64, size 8 + (param "p9" bool) ;; offset 41, size 1 + (param "p0" char) ;; offset 44, size 4 + (param "pa" (list bool)) ;; offset 48, size 8 + (param "pb" (list char)) ;; offset 56, size 8 + (param "pc" (list string)) ;; offset 64, size 8 - (result (tuple (list u8) u32)) + (result $tuple) )) (func (export "many-param") (type $t) (canon lift (core func $i "foo") (memory $i "memory") (realloc (func $i "realloc")) + {async_opts} ) ) )"# ); - let engine = super::engine(); + let mut config = Config::new(); + config.wasm_component_model_async(true); + config.async_support(true); + let engine = &Engine::new(&config)?; let component = Component::new(&engine, component)?; let mut store = Store::new(&engine, ()); - let instance = Linker::new(&engine).instantiate(&mut store, &component)?; - let func = instance.get_typed_func::<( - i8, - u64, - f32, - u8, - i16, - &str, - &[u32], - bool, - bool, - char, - &[bool], - &[char], - &[&str], - ), ((WasmList, u32),)>(&mut store, "many-param")?; + + let instance = Linker::new(&engine) + .instantiate_async(&mut store, &component) + .await?; let input = ( -100, @@ -930,8 +1145,76 @@ fn many_parameters() -> Result<()> { ] .as_slice(), ); - let ((memory, pointer),) = func.call(&mut store, input)?; - let memory = memory.as_le_slice(&store); + + let (memory, pointer) = if dynamic { + let input = vec![ + Val::S8(input.0), + Val::U64(input.1), + Val::Float32(input.2), + Val::U8(input.3), + Val::S16(input.4), + Val::String(input.5.into()), + Val::List(input.6.iter().copied().map(Val::U32).collect()), + Val::Bool(input.7), + Val::Bool(input.8), + Val::Char(input.9), + Val::List(input.10.iter().copied().map(Val::Bool).collect()), + Val::List(input.11.iter().copied().map(Val::Char).collect()), + Val::List(input.12.iter().map(|&s| Val::String(s.into())).collect()), + ]; + let func = instance.get_func(&mut store, "many-param").unwrap(); + + let mut results = if concurrent { + let promise = func.call_concurrent(&mut store, input).await?; + promise.get(&mut store).await?.into_iter() + } else { + let mut results = vec![Val::Bool(false)]; + func.call_async(&mut store, &input, &mut results).await?; + results.into_iter() + }; + + let Some(Val::Tuple(results)) = results.next() else { + panic!() + }; + let mut results = results.into_iter(); + let Some(Val::List(memory)) = results.next() else { + panic!() + }; + let Some(Val::U32(pointer)) = results.next() else { + panic!() + }; + ( + memory + .into_iter() + .map(|v| if let Val::U8(v) = v { v } else { panic!() }) + .collect(), + pointer, + ) + } else { + let func = instance.get_typed_func::<( + i8, + u64, + f32, + u8, + i16, + &str, + &[u32], + bool, + bool, + char, + &[bool], + &[char], + &[&str], + ), ((Vec, u32),)>(&mut store, "many-param")?; + + if concurrent { + let promise = func.call_concurrent(&mut store, input).await?; + promise.get(&mut store).await?.0 + } else { + func.call_async(&mut store, input).await?.0 + } + }; + let memory = &memory[..]; let mut actual = &memory[pointer as usize..][..72]; assert_eq!(i8::from_le_bytes(*actual.take_n::<1>()), input.0); @@ -981,6 +1264,437 @@ fn many_parameters() -> Result<()> { Ok(()) } +#[tokio::test] +async fn many_results() -> Result<()> { + test_many_results(false, false).await +} + +#[tokio::test] +async fn many_results_concurrent() -> Result<()> { + test_many_results(false, true).await +} + +#[tokio::test] +async fn many_results_dynamic() -> Result<()> { + test_many_results(true, false).await +} + +#[tokio::test] +async fn many_results_dynamic_concurrent() -> Result<()> { + test_many_results(true, true).await +} + +async fn test_many_results(dynamic: bool, concurrent: bool) -> Result<()> { + let (ret, async_opts) = if concurrent { + ( + r#" + call $task-return + i32.const 0 + "#, + r#"async (callback (func $i "callback"))"#, + ) + } else { + ("", "") + }; + + let my_nan = CANON_32BIT_NAN | 1; + + let component = format!( + r#"(component + (core module $m + (import "" "task.return" (func $task-return (param i32))) + (memory (export "memory") 1) + (func (export "foo") (result i32) + (local $base i32) + (local $string i32) + (local $list i32) + + (local.set $base + (call $realloc + (i32.const 0) + (i32.const 0) + (i32.const 8) + (i32.const 72))) + + (i32.store8 offset=0 + (local.get $base) + (i32.const -100)) + + (i64.store offset=8 + (local.get $base) + (i64.const 9223372036854775807)) + + (f32.store offset=16 + (local.get $base) + (f32.reinterpret_i32 (i32.const {my_nan}))) + + (i32.store8 offset=20 + (local.get $base) + (i32.const 38)) + + (i32.store16 offset=22 + (local.get $base) + (i32.const 18831)) + + (local.set $string + (call $realloc + (i32.const 0) + (i32.const 0) + (i32.const 1) + (i32.const 6))) + + (i32.store8 offset=0 + (local.get $string) + (i32.const 97)) ;; 'a' + (i32.store8 offset=1 + (local.get $string) + (i32.const 98)) ;; 'b' + (i32.store8 offset=2 + (local.get $string) + (i32.const 99)) ;; 'c' + (i32.store8 offset=3 + (local.get $string) + (i32.const 100)) ;; 'd' + (i32.store8 offset=4 + (local.get $string) + (i32.const 101)) ;; 'e' + (i32.store8 offset=5 + (local.get $string) + (i32.const 102)) ;; 'f' + + (i32.store offset=24 + (local.get $base) + (local.get $string)) + + (i32.store offset=28 + (local.get $base) + (i32.const 2)) + + (local.set $list + (call $realloc + (i32.const 0) + (i32.const 0) + (i32.const 4) + (i32.const 32))) + + (i32.store offset=0 + (local.get $list) + (i32.const 1)) + (i32.store offset=4 + (local.get $list) + (i32.const 2)) + (i32.store offset=8 + (local.get $list) + (i32.const 3)) + (i32.store offset=12 + (local.get $list) + (i32.const 4)) + (i32.store offset=16 + (local.get $list) + (i32.const 5)) + (i32.store offset=20 + (local.get $list) + (i32.const 6)) + (i32.store offset=24 + (local.get $list) + (i32.const 7)) + (i32.store offset=28 + (local.get $list) + (i32.const 8)) + + (i32.store offset=32 + (local.get $base) + (local.get $list)) + + (i32.store offset=36 + (local.get $base) + (i32.const 8)) + + (i32.store8 offset=40 + (local.get $base) + (i32.const 1)) + + (i32.store8 offset=41 + (local.get $base) + (i32.const 0)) + + (i32.store offset=44 + (local.get $base) + (i32.const 128681)) ;; '🚩' + + (local.set $list + (call $realloc + (i32.const 0) + (i32.const 0) + (i32.const 1) + (i32.const 5))) + + (i32.store8 offset=0 + (local.get $list) + (i32.const 0)) + (i32.store8 offset=1 + (local.get $list) + (i32.const 1)) + (i32.store8 offset=2 + (local.get $list) + (i32.const 0)) + (i32.store8 offset=3 + (local.get $list) + (i32.const 1)) + (i32.store8 offset=4 + (local.get $list) + (i32.const 1)) + + (i32.store offset=48 + (local.get $base) + (local.get $list)) + + (i32.store offset=52 + (local.get $base) + (i32.const 5)) + + (local.set $list + (call $realloc + (i32.const 0) + (i32.const 0) + (i32.const 4) + (i32.const 20))) + + (i32.store offset=0 + (local.get $list) + (i32.const 127820)) ;; '🍌' + (i32.store offset=4 + (local.get $list) + (i32.const 129360)) ;; '🥐' + (i32.store offset=8 + (local.get $list) + (i32.const 127831)) ;; '🍗' + (i32.store offset=12 + (local.get $list) + (i32.const 127833)) ;; '🍙' + (i32.store offset=16 + (local.get $list) + (i32.const 127841)) ;; '🍡' + + (i32.store offset=56 + (local.get $base) + (local.get $list)) + + (i32.store offset=60 + (local.get $base) + (i32.const 5)) + + (local.set $list + (call $realloc + (i32.const 0) + (i32.const 0) + (i32.const 4) + (i32.const 16))) + + (i32.store offset=0 + (local.get $list) + (i32.add (local.get $string) (i32.const 2))) + (i32.store offset=4 + (local.get $list) + (i32.const 2)) + (i32.store offset=8 + (local.get $list) + (i32.add (local.get $string) (i32.const 4))) + (i32.store offset=12 + (local.get $list) + (i32.const 2)) + + (i32.store offset=64 + (local.get $base) + (local.get $list)) + + (i32.store offset=68 + (local.get $base) + (i32.const 2)) + + local.get $base + + {ret} + ) + (func (export "callback") (param i32 i32 i32 i32) (result i32) unreachable) + + {REALLOC_AND_FREE} + ) + (type $tuple (tuple + s8 + u64 + float32 + u8 + s16 + string + (list u32) + bool + bool + char + (list bool) + (list char) + (list string) + )) + (core func $task-return (canon task.return (result $tuple))) + (core instance $i (instantiate $m + (with "" (instance (export "task.return" (func $task-return)))) + )) + + (type $t (func (result $tuple))) + (func (export "many-results") (type $t) + (canon lift + (core func $i "foo") + (memory $i "memory") + (realloc (func $i "realloc")) + {async_opts} + ) + ) + )"# + ); + + let mut config = Config::new(); + config.wasm_component_model_async(true); + config.async_support(true); + let engine = &Engine::new(&config)?; + let component = Component::new(&engine, component)?; + let mut store = Store::new(&engine, ()); + + let instance = Linker::new(&engine) + .instantiate_async(&mut store, &component) + .await?; + + let expected = ( + -100i8, + u64::MAX / 2, + f32::from_bits(CANON_32BIT_NAN | 1), + 38u8, + 18831i16, + "ab".to_string(), + vec![1u32, 2, 3, 4, 5, 6, 7, 8], + true, + false, + '🚩', + vec![false, true, false, true, true], + vec!['🍌', '🥐', '🍗', '🍙', '🍡'], + vec!["cd".to_string(), "ef".to_string()], + ); + + let actual = if dynamic { + let func = instance.get_func(&mut store, "many-results").unwrap(); + + let mut results = if concurrent { + let promise = func.call_concurrent(&mut store, Vec::new()).await?; + promise.get(&mut store).await?.into_iter() + } else { + let mut results = vec![Val::Bool(false)]; + func.call_async(&mut store, &[], &mut results).await?; + results.into_iter() + }; + + let Some(Val::Tuple(results)) = results.next() else { + panic!() + }; + let mut results = results.into_iter(); + let Some(Val::S8(p1)) = results.next() else { + panic!() + }; + let Some(Val::U64(p2)) = results.next() else { + panic!() + }; + let Some(Val::Float32(p3)) = results.next() else { + panic!() + }; + let Some(Val::U8(p4)) = results.next() else { + panic!() + }; + let Some(Val::S16(p5)) = results.next() else { + panic!() + }; + let Some(Val::String(p6)) = results.next() else { + panic!() + }; + let Some(Val::List(p7)) = results.next() else { + panic!() + }; + let p7 = p7 + .into_iter() + .map(|v| if let Val::U32(v) = v { v } else { panic!() }) + .collect(); + let Some(Val::Bool(p8)) = results.next() else { + panic!() + }; + let Some(Val::Bool(p9)) = results.next() else { + panic!() + }; + let Some(Val::Char(p0)) = results.next() else { + panic!() + }; + let Some(Val::List(pa)) = results.next() else { + panic!() + }; + let pa = pa + .into_iter() + .map(|v| if let Val::Bool(v) = v { v } else { panic!() }) + .collect(); + let Some(Val::List(pb)) = results.next() else { + panic!() + }; + let pb = pb + .into_iter() + .map(|v| if let Val::Char(v) = v { v } else { panic!() }) + .collect(); + let Some(Val::List(pc)) = results.next() else { + panic!() + }; + let pc = pc + .into_iter() + .map(|v| if let Val::String(v) = v { v } else { panic!() }) + .collect(); + + (p1, p2, p3, p4, p5, p6, p7, p8, p9, p0, pa, pb, pc) + } else { + let func = instance.get_typed_func::<(), (( + i8, + u64, + f32, + u8, + i16, + String, + Vec, + bool, + bool, + char, + Vec, + Vec, + Vec, + ),)>(&mut store, "many-results")?; + + if concurrent { + let promise = func.call_concurrent(&mut store, ()).await?; + promise.get(&mut store).await?.0 + } else { + func.call_async(&mut store, ()).await?.0 + } + }; + + assert_eq!(expected.0, actual.0); + assert_eq!(expected.1, actual.1); + assert!(expected.2.is_nan()); + assert!(actual.2.is_nan()); + assert_eq!(expected.3, actual.3); + assert_eq!(expected.4, actual.4); + assert_eq!(expected.5, actual.5); + assert_eq!(expected.6, actual.6); + assert_eq!(expected.7, actual.7); + assert_eq!(expected.8, actual.8); + assert_eq!(expected.9, actual.9); + assert_eq!(expected.10, actual.10); + assert_eq!(expected.11, actual.11); + assert_eq!(expected.12, actual.12); + + Ok(()) +} + #[test] fn some_traps() -> Result<()> { let middle_of_memory = (i32::MAX / 2) & (!0xff); diff --git a/tests/all/component_model/import.rs b/tests/all/component_model/import.rs index 2dfd6a37ce09..285e173ff522 100644 --- a/tests/all/component_model/import.rs +++ b/tests/all/component_model/import.rs @@ -3,8 +3,9 @@ use super::REALLOC_AND_FREE; use anyhow::Result; use std::ops::Deref; +use wasmtime::component; use wasmtime::component::*; -use wasmtime::{Store, StoreContextMut, Trap, WasmBacktrace}; +use wasmtime::{Config, Engine, Store, StoreContextMut, Trap, WasmBacktrace}; #[test] fn can_compile() -> Result<()> { @@ -481,37 +482,86 @@ fn attempt_to_reenter_during_host() -> Result<()> { Ok(()) } -#[test] -fn stack_and_heap_args_and_rets() -> Result<()> { - let component = format!( - r#" -(component - (type $many_params (tuple - string string string string - string string string string - string)) - (import "f1" (func $f1 (param "a" u32) (result u32))) - (import "f2" (func $f2 (param "a" $many_params) (result u32))) - (import "f3" (func $f3 (param "a" u32) (result string))) - (import "f4" (func $f4 (param "a" $many_params) (result string))) +#[tokio::test] +async fn stack_and_heap_args_and_rets() -> Result<()> { + test_stack_and_heap_args_and_rets(false).await +} - (core module $libc - {REALLOC_AND_FREE} - (memory (export "memory") 1) - ) - (core instance $libc (instantiate (module $libc))) +#[tokio::test] +async fn stack_and_heap_args_and_rets_concurrent() -> Result<()> { + test_stack_and_heap_args_and_rets(true).await +} - (core func $f1_lower (canon lower (func $f1) (memory $libc "memory") (realloc (func $libc "realloc")))) - (core func $f2_lower (canon lower (func $f2) (memory $libc "memory") (realloc (func $libc "realloc")))) - (core func $f3_lower (canon lower (func $f3) (memory $libc "memory") (realloc (func $libc "realloc")))) - (core func $f4_lower (canon lower (func $f4) (memory $libc "memory") (realloc (func $libc "realloc")))) +async fn test_stack_and_heap_args_and_rets(concurrent: bool) -> Result<()> { + let (body, async_lower_opts, async_lift_opts) = if concurrent { + ( + r#" + (import "host" "f1" (func $f1 (param i32 i32) (result i32))) + (import "host" "f2" (func $f2 (param i32 i32) (result i32))) + (import "host" "f3" (func $f3 (param i32 i32) (result i32))) + (import "host" "f4" (func $f4 (param i32 i32) (result i32))) - (core module $m + (func $run (export "run") (result i32) + (local $params i32) + (local $results i32) + + block + (local.set $params (call $realloc (i32.const 0) (i32.const 0) (i32.const 4) (i32.const 4))) + (i32.store offset=0 (local.get $params) (i32.const 1)) + (local.set $results (call $realloc (i32.const 0) (i32.const 0) (i32.const 4) (i32.const 4))) + (call $f1 (local.get $params) (local.get $results)) + drop + (i32.load offset=0 (local.get $results)) + i32.const 2 + i32.eq + br_if 0 + unreachable + end + + block + (local.set $params (call $allocate_empty_strings)) + (local.set $results (call $realloc (i32.const 0) (i32.const 0) (i32.const 4) (i32.const 4))) + (call $f2 (local.get $params) (local.get $results)) + drop + (i32.load offset=0 (local.get $results)) + i32.const 3 + i32.eq + br_if 0 + unreachable + end + + block + (local.set $params (call $realloc (i32.const 0) (i32.const 0) (i32.const 4) (i32.const 4))) + (i32.store offset=0 (local.get $params) (i32.const 8)) + (local.set $results (call $realloc (i32.const 0) (i32.const 0) (i32.const 4) (i32.const 8))) + (call $f3 (local.get $params) (local.get $results)) + drop + (call $validate_string_ret (local.get $results)) + end + + block + (local.set $params (call $allocate_empty_strings)) + (local.set $results (call $realloc (i32.const 0) (i32.const 0) (i32.const 4) (i32.const 8))) + (call $f4 (local.get $params) (local.get $results)) + drop + (call $validate_string_ret (local.get $results)) + end + + (call $task-return) + + i32.const 0 + ) + "#, + "async", + r#"async (callback (func $m "callback"))"#, + ) + } else { + ( + r#" (import "host" "f1" (func $f1 (param i32) (result i32))) (import "host" "f2" (func $f2 (param i32) (result i32))) (import "host" "f3" (func $f3 (param i32 i32))) (import "host" "f4" (func $f4 (param i32 i32))) - (import "libc" "memory" (memory 1)) (func $run (export "run") block @@ -546,6 +596,58 @@ fn stack_and_heap_args_and_rets() -> Result<()> { (call $validate_string_ret (i32.const 20000)) end ) + "#, + "", + "", + ) + }; + + let component = format!( + r#" +(component + (type $many_params (tuple + string string string string + string string string string + string)) + (import "f1" (func $f1 (param "a" u32) (result u32))) + (import "f2" (func $f2 (param "a" $many_params) (result u32))) + (import "f3" (func $f3 (param "a" u32) (result string))) + (import "f4" (func $f4 (param "a" $many_params) (result string))) + + (core module $libc + {REALLOC_AND_FREE} + (memory (export "memory") 1) + ) + (core instance $libc (instantiate (module $libc))) + + (core func $f1_lower (canon lower (func $f1) + (memory $libc "memory") + (realloc (func $libc "realloc")) + {async_lower_opts} + )) + (core func $f2_lower (canon lower (func $f2) + (memory $libc "memory") + (realloc (func $libc "realloc")) + {async_lower_opts} + )) + (core func $f3_lower (canon lower (func $f3) + (memory $libc "memory") + (realloc (func $libc "realloc")) + {async_lower_opts} + )) + (core func $f4_lower (canon lower (func $f4) + (memory $libc "memory") + (realloc (func $libc "realloc")) + {async_lower_opts} + )) + + (core module $m + (import "libc" "memory" (memory 1)) + (import "libc" "realloc" (func $realloc (param i32 i32 i32 i32) (result i32))) + (import "host" "task.return" (func $task-return)) + {body} + + (func (export "callback") (param i32 i32 i32 i32) (result i32) unreachable) (func $allocate_empty_strings (result i32) (local $ret i32) @@ -601,6 +703,7 @@ fn stack_and_heap_args_and_rets() -> Result<()> { (data (i32.const 1000) "abc") ) + (core func $task-return (canon task.return)) (core instance $m (instantiate $m (with "libc" (instance $libc)) (with "host" (instance @@ -608,130 +711,239 @@ fn stack_and_heap_args_and_rets() -> Result<()> { (export "f2" (func $f2_lower)) (export "f3" (func $f3_lower)) (export "f4" (func $f4_lower)) + (export "task.return" (func $task-return)) )) )) (func (export "run") - (canon lift (core func $m "run")) + (canon lift (core func $m "run") {async_lift_opts}) ) ) "# ); - let engine = super::engine(); + let mut config = Config::new(); + config.wasm_component_model_async(true); + config.async_support(true); + let engine = &Engine::new(&config)?; let component = Component::new(&engine, component)?; let mut store = Store::new(&engine, ()); // First, test the static API let mut linker = Linker::new(&engine); - linker - .root() - .func_wrap("f1", |_, (x,): (u32,)| -> Result<(u32,)> { - assert_eq!(x, 1); - Ok((2,)) - })?; - linker.root().func_wrap( - "f2", - |cx: StoreContextMut<'_, ()>, - (arg,): (( - WasmStr, - WasmStr, - WasmStr, - WasmStr, - WasmStr, - WasmStr, - WasmStr, - WasmStr, - WasmStr, - ),)| - -> Result<(u32,)> { - assert_eq!(arg.0.to_str(&cx).unwrap(), "abc"); - Ok((3,)) - }, - )?; - linker - .root() - .func_wrap("f3", |_, (arg,): (u32,)| -> Result<(String,)> { - assert_eq!(arg, 8); - Ok(("xyz".to_string(),)) - })?; - linker.root().func_wrap( - "f4", - |cx: StoreContextMut<'_, ()>, - (arg,): (( - WasmStr, - WasmStr, - WasmStr, - WasmStr, - WasmStr, - WasmStr, - WasmStr, - WasmStr, - WasmStr, - ),)| - -> Result<(String,)> { - assert_eq!(arg.0.to_str(&cx).unwrap(), "abc"); - Ok(("xyz".to_string(),)) - }, - )?; - let instance = linker.instantiate(&mut store, &component)?; - instance - .get_typed_func::<(), ()>(&mut store, "run")? - .call(&mut store, ())?; + if concurrent { + linker + .root() + .func_wrap_concurrent("f1", |_, (x,): (u32,)| { + assert_eq!(x, 1); + async { component::for_any(|_| Ok((2u32,))) } + })?; + linker.root().func_wrap_concurrent( + "f2", + |cx: StoreContextMut<'_, ()>, + (arg,): (( + WasmStr, + WasmStr, + WasmStr, + WasmStr, + WasmStr, + WasmStr, + WasmStr, + WasmStr, + WasmStr, + ),)| { + assert_eq!(arg.0.to_str(&cx).unwrap(), "abc"); + async { component::for_any(|_| Ok((3u32,))) } + }, + )?; + linker + .root() + .func_wrap_concurrent("f3", |_, (arg,): (u32,)| { + assert_eq!(arg, 8); + async { component::for_any(|_| Ok(("xyz".to_string(),))) } + })?; + linker.root().func_wrap_concurrent( + "f4", + |cx: StoreContextMut<'_, ()>, + (arg,): (( + WasmStr, + WasmStr, + WasmStr, + WasmStr, + WasmStr, + WasmStr, + WasmStr, + WasmStr, + WasmStr, + ),)| { + assert_eq!(arg.0.to_str(&cx).unwrap(), "abc"); + async { component::for_any(|_| Ok(("xyz".to_string(),))) } + }, + )?; + } else { + linker + .root() + .func_wrap("f1", |_, (x,): (u32,)| -> Result<(u32,)> { + assert_eq!(x, 1); + Ok((2,)) + })?; + linker.root().func_wrap( + "f2", + |cx: StoreContextMut<'_, ()>, + (arg,): (( + WasmStr, + WasmStr, + WasmStr, + WasmStr, + WasmStr, + WasmStr, + WasmStr, + WasmStr, + WasmStr, + ),)| + -> Result<(u32,)> { + assert_eq!(arg.0.to_str(&cx).unwrap(), "abc"); + Ok((3,)) + }, + )?; + linker + .root() + .func_wrap("f3", |_, (arg,): (u32,)| -> Result<(String,)> { + assert_eq!(arg, 8); + Ok(("xyz".to_string(),)) + })?; + linker.root().func_wrap( + "f4", + |cx: StoreContextMut<'_, ()>, + (arg,): (( + WasmStr, + WasmStr, + WasmStr, + WasmStr, + WasmStr, + WasmStr, + WasmStr, + WasmStr, + WasmStr, + ),)| + -> Result<(String,)> { + assert_eq!(arg.0.to_str(&cx).unwrap(), "abc"); + Ok(("xyz".to_string(),)) + }, + )?; + } + + let instance = linker.instantiate_async(&mut store, &component).await?; + let run = instance.get_typed_func::<(), ()>(&mut store, "run")?; + + if concurrent { + let promise = run.call_concurrent(&mut store, ()).await?; + promise.get(&mut store).await?; + } else { + run.call_async(&mut store, ()).await?; + } // Next, test the dynamic API let mut linker = Linker::new(&engine); - linker.root().func_new("f1", |_, args, results| { - if let Val::U32(x) = &args[0] { - assert_eq!(*x, 1); - results[0] = Val::U32(2); - Ok(()) - } else { - panic!() - } - })?; - linker.root().func_new("f2", |_, args, results| { - if let Val::Tuple(tuple) = &args[0] { - if let Val::String(s) = &tuple[0] { - assert_eq!(s.deref(), "abc"); - results[0] = Val::U32(3); + if concurrent { + linker.root().func_new_concurrent("f1", |_, args| { + if let Val::U32(x) = &args[0] { + assert_eq!(*x, 1); + async { component::for_any(|_| Ok(vec![Val::U32(2)])) } + } else { + panic!() + } + })?; + linker.root().func_new_concurrent("f2", |_, args| { + if let Val::Tuple(tuple) = &args[0] { + if let Val::String(s) = &tuple[0] { + assert_eq!(s.deref(), "abc"); + async { component::for_any(|_| Ok(vec![Val::U32(3)])) } + } else { + panic!() + } + } else { + panic!() + } + })?; + linker.root().func_new_concurrent("f3", |_, args| { + if let Val::U32(x) = &args[0] { + assert_eq!(*x, 8); + async { component::for_any(|_| Ok(vec![Val::String("xyz".into())])) } + } else { + panic!(); + } + })?; + linker.root().func_new_concurrent("f4", |_, args| { + if let Val::Tuple(tuple) = &args[0] { + if let Val::String(s) = &tuple[0] { + assert_eq!(s.deref(), "abc"); + async { component::for_any(|_| Ok(vec![Val::String("xyz".into())])) } + } else { + panic!() + } + } else { + panic!() + } + })?; + } else { + linker.root().func_new("f1", |_, args, results| { + if let Val::U32(x) = &args[0] { + assert_eq!(*x, 1); + results[0] = Val::U32(2); Ok(()) } else { panic!() } - } else { - panic!() - } - })?; - linker.root().func_new("f3", |_, args, results| { - if let Val::U32(x) = &args[0] { - assert_eq!(*x, 8); - results[0] = Val::String("xyz".into()); - Ok(()) - } else { - panic!(); - } - })?; - linker.root().func_new("f4", |_, args, results| { - if let Val::Tuple(tuple) = &args[0] { - if let Val::String(s) = &tuple[0] { - assert_eq!(s.deref(), "abc"); + })?; + linker.root().func_new("f2", |_, args, results| { + if let Val::Tuple(tuple) = &args[0] { + if let Val::String(s) = &tuple[0] { + assert_eq!(s.deref(), "abc"); + results[0] = Val::U32(3); + Ok(()) + } else { + panic!() + } + } else { + panic!() + } + })?; + linker.root().func_new("f3", |_, args, results| { + if let Val::U32(x) = &args[0] { + assert_eq!(*x, 8); results[0] = Val::String("xyz".into()); Ok(()) + } else { + panic!(); + } + })?; + linker.root().func_new("f4", |_, args, results| { + if let Val::Tuple(tuple) = &args[0] { + if let Val::String(s) = &tuple[0] { + assert_eq!(s.deref(), "abc"); + results[0] = Val::String("xyz".into()); + Ok(()) + } else { + panic!() + } } else { panic!() } - } else { - panic!() - } - })?; - let instance = linker.instantiate(&mut store, &component)?; - instance - .get_func(&mut store, "run") - .unwrap() - .call(&mut store, &[], &mut [])?; + })?; + } + + let instance = linker.instantiate_async(&mut store, &component).await?; + let run = instance.get_func(&mut store, "run").unwrap(); + + if concurrent { + let promise = run.call_concurrent(&mut store, Vec::new()).await?; + promise.get(&mut store).await?; + } else { + run.call_async(&mut store, &[], &mut []).await?; + } Ok(()) } diff --git a/tests/misc_testsuite/component-model-async/error-context.wast b/tests/misc_testsuite/component-model-async/error-context.wast new file mode 100644 index 000000000000..e564416a5109 --- /dev/null +++ b/tests/misc_testsuite/component-model-async/error-context.wast @@ -0,0 +1,35 @@ +;;! component_model_async = true + +;; error-context.new +(component + (core module $libc (memory (export "memory") 1)) + (core instance $libc (instantiate $libc)) + (core module $m + (import "" "error-context.new" (func $error-context-new (param i32 i32) (result i32))) + ) + (core func $error-context-new (canon error-context.new (memory $libc "memory"))) + (core instance $i (instantiate $m (with "" (instance (export "error-context.new" (func $error-context-new)))))) +) + +;; error-context.debug-message +(component + (core module $libc + (func (export "realloc") (param i32 i32 i32 i32) (result i32) unreachable) + (memory (export "memory") 1) + ) + (core instance $libc (instantiate $libc)) + (core module $m + (import "" "error-context.debug-message" (func $error-context-debug-message (param i32 i32))) + ) + (core func $error-context-debug-message (canon error-context.debug-message (memory $libc "memory") (realloc (func $libc "realloc")))) + (core instance $i (instantiate $m (with "" (instance (export "error-context.debug-message" (func $error-context-debug-message)))))) +) + +;; error-context.drop +(component + (core module $m + (import "" "error-context.drop" (func $error-context-drop (param i32))) + ) + (core func $error-context-drop (canon error-context.drop)) + (core instance $i (instantiate $m (with "" (instance (export "error-context.drop" (func $error-context-drop)))))) +) diff --git a/tests/misc_testsuite/component-model-async/futures.wast b/tests/misc_testsuite/component-model-async/futures.wast new file mode 100644 index 000000000000..f1e4d4d5b940 --- /dev/null +++ b/tests/misc_testsuite/component-model-async/futures.wast @@ -0,0 +1,90 @@ +;;! component_model_async = true + +;; future.new +(component + (core module $m + (import "" "future.new" (func $future-new (result i32))) + ) + (type $future-type (future u8)) + (core func $future-new (canon future.new $future-type)) + (core instance $i (instantiate $m (with "" (instance (export "future.new" (func $future-new)))))) +) + +;; future.read +(component + (core module $libc (memory (export "memory") 1)) + (core instance $libc (instantiate $libc)) + (core module $m + (import "" "future.read" (func $future-read (param i32 i32) (result i32))) + ) + (type $future-type (future u8)) + (core func $future-read (canon future.read $future-type async (memory $libc "memory"))) + (core instance $i (instantiate $m (with "" (instance (export "future.read" (func $future-read)))))) +) + +;; future.read; with realloc +(component + (core module $libc + (func (export "realloc") (param i32 i32 i32 i32) (result i32) unreachable) + (memory (export "memory") 1) + ) + (core instance $libc (instantiate $libc)) + (core module $m + (import "" "future.read" (func $future-read (param i32 i32) (result i32))) + ) + (type $future-type (future string)) + (core func $future-read (canon future.read $future-type async (memory $libc "memory") (realloc (func $libc "realloc")))) + (core instance $i (instantiate $m (with "" (instance (export "future.read" (func $future-read)))))) +) + +;; future.write +(component + (core module $libc (memory (export "memory") 1)) + (core instance $libc (instantiate $libc)) + (core module $m + (import "" "future.write" (func $future-write (param i32 i32) (result i32))) + ) + (type $future-type (future u8)) + (core func $future-write (canon future.write $future-type async (memory $libc "memory"))) + (core instance $i (instantiate $m (with "" (instance (export "future.write" (func $future-write)))))) +) + +;; future.cancel-read +(component + (core module $m + (import "" "future.cancel-read" (func $future-cancel-read (param i32) (result i32))) + ) + (type $future-type (future u8)) + (core func $future-cancel-read (canon future.cancel-read $future-type async)) + (core instance $i (instantiate $m (with "" (instance (export "future.cancel-read" (func $future-cancel-read)))))) +) + +;; future.cancel-write +(component + (core module $m + (import "" "future.cancel-write" (func $future-cancel-write (param i32) (result i32))) + ) + (type $future-type (future u8)) + (core func $future-cancel-write (canon future.cancel-write $future-type async)) + (core instance $i (instantiate $m (with "" (instance (export "future.cancel-write" (func $future-cancel-write)))))) +) + +;; future.close-readable +(component + (core module $m + (import "" "future.close-readable" (func $future-close-readable (param i32))) + ) + (type $future-type (future u8)) + (core func $future-close-readable (canon future.close-readable $future-type)) + (core instance $i (instantiate $m (with "" (instance (export "future.close-readable" (func $future-close-readable)))))) +) + +;; future.close-writable +(component + (core module $m + (import "" "future.close-writable" (func $future-close-writable (param i32 i32))) + ) + (type $future-type (future u8)) + (core func $future-close-writable (canon future.close-writable $future-type)) + (core instance $i (instantiate $m (with "" (instance (export "future.close-writable" (func $future-close-writable)))))) +) diff --git a/tests/misc_testsuite/component-model-async/streams.wast b/tests/misc_testsuite/component-model-async/streams.wast new file mode 100644 index 000000000000..790ddec7e5f8 --- /dev/null +++ b/tests/misc_testsuite/component-model-async/streams.wast @@ -0,0 +1,90 @@ +;;! component_model_async = true + +;; stream.new +(component + (core module $m + (import "" "stream.new" (func $stream-new (result i32))) + ) + (type $stream-type (stream u8)) + (core func $stream-new (canon stream.new $stream-type)) + (core instance $i (instantiate $m (with "" (instance (export "stream.new" (func $stream-new)))))) +) + +;; stream.read +(component + (core module $libc (memory (export "memory") 1)) + (core instance $libc (instantiate $libc)) + (core module $m + (import "" "stream.read" (func $stream-read (param i32 i32 i32) (result i32))) + ) + (type $stream-type (stream u8)) + (core func $stream-read (canon stream.read $stream-type async (memory $libc "memory"))) + (core instance $i (instantiate $m (with "" (instance (export "stream.read" (func $stream-read)))))) +) + +;; stream.read; with realloc +(component + (core module $libc + (func (export "realloc") (param i32 i32 i32 i32) (result i32) unreachable) + (memory (export "memory") 1) + ) + (core instance $libc (instantiate $libc)) + (core module $m + (import "" "stream.read" (func $stream-read (param i32 i32 i32) (result i32))) + ) + (type $stream-type (stream string)) + (core func $stream-read (canon stream.read $stream-type async (memory $libc "memory") (realloc (func $libc "realloc")))) + (core instance $i (instantiate $m (with "" (instance (export "stream.read" (func $stream-read)))))) +) + +;; stream.write +(component + (core module $libc (memory (export "memory") 1)) + (core instance $libc (instantiate $libc)) + (core module $m + (import "" "stream.write" (func $stream-write (param i32 i32 i32) (result i32))) + ) + (type $stream-type (stream u8)) + (core func $stream-write (canon stream.write $stream-type async (memory $libc "memory"))) + (core instance $i (instantiate $m (with "" (instance (export "stream.write" (func $stream-write)))))) +) + +;; stream.cancel-read +(component + (core module $m + (import "" "stream.cancel-read" (func $stream-cancel-read (param i32) (result i32))) + ) + (type $stream-type (stream u8)) + (core func $stream-cancel-read (canon stream.cancel-read $stream-type async)) + (core instance $i (instantiate $m (with "" (instance (export "stream.cancel-read" (func $stream-cancel-read)))))) +) + +;; stream.cancel-write +(component + (core module $m + (import "" "stream.cancel-write" (func $stream-cancel-write (param i32) (result i32))) + ) + (type $stream-type (stream u8)) + (core func $stream-cancel-write (canon stream.cancel-write $stream-type async)) + (core instance $i (instantiate $m (with "" (instance (export "stream.cancel-write" (func $stream-cancel-write)))))) +) + +;; stream.close-readable +(component + (core module $m + (import "" "stream.close-readable" (func $stream-close-readable (param i32))) + ) + (type $stream-type (stream u8)) + (core func $stream-close-readable (canon stream.close-readable $stream-type)) + (core instance $i (instantiate $m (with "" (instance (export "stream.close-readable" (func $stream-close-readable)))))) +) + +;; stream.close-writable +(component + (core module $m + (import "" "stream.close-writable" (func $stream-close-writable (param i32 i32))) + ) + (type $stream-type (stream u8)) + (core func $stream-close-writable (canon stream.close-writable $stream-type)) + (core instance $i (instantiate $m (with "" (instance (export "stream.close-writable" (func $stream-close-writable)))))) +) diff --git a/tests/misc_testsuite/component-model-async/task-builtins.wast b/tests/misc_testsuite/component-model-async/task-builtins.wast index a5f9ca0f468e..d68a56709326 100644 --- a/tests/misc_testsuite/component-model-async/task-builtins.wast +++ b/tests/misc_testsuite/component-model-async/task-builtins.wast @@ -14,8 +14,7 @@ (core module $m (import "" "task.return" (func $task-return (param i32))) ) - (core type $task-return-type (func (param i32))) - (core func $task-return (canon task.return $task-return-type)) + (core func $task-return (canon task.return (result u32))) (core instance $i (instantiate $m (with "" (instance (export "task.return" (func $task-return)))))) )