Skip to content

Commit 9679867

Browse files
mhmd-azeezevacchi
andauthored
fix: make sure we instantiate non-main modules (#93)
Fixes #92 --------- Signed-off-by: Edoardo Vacchi <[email protected]> Co-authored-by: Edoardo Vacchi <[email protected]>
1 parent 1e14b80 commit 9679867

14 files changed

+440
-67
lines changed

extism.go

+14-16
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,9 @@ func RuntimeVersion() string {
3737

3838
// Runtime represents the Extism plugin's runtime environment, including the underlying Wazero runtime and modules.
3939
type Runtime struct {
40-
Wazero wazero.Runtime
41-
Extism api.Module
42-
Env api.Module
43-
hasWasi bool
40+
Wazero wazero.Runtime
41+
Extism api.Module
42+
Env api.Module
4443
}
4544

4645
// PluginInstanceConfig contains configuration options for the Extism plugin.
@@ -112,13 +111,12 @@ func (l LogLevel) String() string {
112111

113112
// Plugin is used to call WASM functions
114113
type Plugin struct {
115-
close []func(ctx context.Context) error
116-
extism api.Module
117-
118-
module api.Module
119-
Timeout time.Duration
120-
Config map[string]string
121-
// NOTE: maybe we can have some nice methods for getting/setting vars
114+
close []func(ctx context.Context) error
115+
extism api.Module
116+
mainModule api.Module
117+
modules map[string]api.Module
118+
Timeout time.Duration
119+
Config map[string]string
122120
Var map[string][]byte
123121
AllowedHosts []string
124122
AllowedPaths map[string]string
@@ -138,7 +136,7 @@ func logStd(level LogLevel, message string) {
138136
}
139137

140138
func (p *Plugin) Module() *Module {
141-
return &Module{inner: p.module}
139+
return &Module{inner: p.mainModule}
142140
}
143141

144142
// SetLogger sets a custom logging callback
@@ -443,7 +441,7 @@ func (p *Plugin) GetErrorWithContext(ctx context.Context) string {
443441

444442
// FunctionExists returns true when the named function is present in the plugin's main Module
445443
func (p *Plugin) FunctionExists(name string) bool {
446-
return p.module.ExportedFunction(name) != nil
444+
return p.mainModule.ExportedFunction(name) != nil
447445
}
448446

449447
// Call a function by name with the given input, returning the output
@@ -469,15 +467,15 @@ func (p *Plugin) CallWithContext(ctx context.Context, name string, data []byte)
469467

470468
ctx = context.WithValue(ctx, InputOffsetKey("inputOffset"), intputOffset)
471469

472-
var f = p.module.ExportedFunction(name)
470+
var f = p.mainModule.ExportedFunction(name)
473471

474472
if f == nil {
475473
return 1, []byte{}, fmt.Errorf("unknown function: %s", name)
476474
} else if n := len(f.Definition().ResultTypes()); n > 1 {
477475
return 1, []byte{}, fmt.Errorf("function %s has %v results, expected 0 or 1", name, n)
478476
}
479477

480-
var isStart = name == "_start"
478+
var isStart = name == "_start" || name == "_initialize"
481479
if p.guestRuntime.init != nil && !isStart && !p.guestRuntime.initialized {
482480
err := p.guestRuntime.init(ctx)
483481
if err != nil {
@@ -501,7 +499,7 @@ func (p *Plugin) CallWithContext(ctx context.Context, name string, data []byte)
501499
if exitCode == 0 {
502500
// It's possible for the function to return 0 as an error code, even
503501
// if the module is closed.
504-
if p.module.IsClosed() {
502+
if p.mainModule.IsClosed() {
505503
return 0, nil, fmt.Errorf("module is closed")
506504
}
507505
err = nil

extism_test.go

+171-6
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,13 @@ import (
55
"context"
66
"encoding/json"
77
"fmt"
8+
"log"
9+
"os"
10+
"strings"
11+
"sync"
12+
"testing"
13+
"time"
14+
815
observe "github.com/dylibso/observe-sdk/go"
916
"github.com/dylibso/observe-sdk/go/adapter/stdout"
1017
"github.com/stretchr/testify/assert"
@@ -13,12 +20,6 @@ import (
1320
"github.com/tetratelabs/wazero/experimental"
1421
"github.com/tetratelabs/wazero/experimental/logging"
1522
"github.com/tetratelabs/wazero/sys"
16-
"log"
17-
"os"
18-
"strings"
19-
"sync"
20-
"testing"
21-
"time"
2223
)
2324

2425
func TestWasmUrl(t *testing.T) {
@@ -1038,6 +1039,170 @@ func TestEnableExperimentalFeature(t *testing.T) {
10381039
}
10391040
}
10401041

1042+
func TestModuleLinking(t *testing.T) {
1043+
manifest := Manifest{
1044+
Wasm: []Wasm{
1045+
WasmFile{
1046+
Path: "wasm/lib.wasm",
1047+
Name: "lib",
1048+
},
1049+
WasmFile{
1050+
Path: "wasm/main.wasm",
1051+
Name: "main",
1052+
},
1053+
},
1054+
}
1055+
1056+
if plugin, ok := pluginInstance(t, manifest); ok {
1057+
defer plugin.Close(context.Background())
1058+
1059+
exit, output, err := plugin.Call("run_test", []byte("benjamin"))
1060+
1061+
if assertCall(t, err, exit) {
1062+
expected := "Hello, BENJAMIN"
1063+
1064+
actual := string(output)
1065+
1066+
assert.Equal(t, expected, actual)
1067+
}
1068+
}
1069+
}
1070+
1071+
func TestModuleLinkingMultipleInstances(t *testing.T) {
1072+
manifest := Manifest{
1073+
Wasm: []Wasm{
1074+
WasmFile{
1075+
Path: "wasm/lib.wasm",
1076+
Name: "lib",
1077+
},
1078+
WasmFile{
1079+
Path: "wasm/main.wasm",
1080+
Name: "main",
1081+
},
1082+
},
1083+
}
1084+
1085+
ctx := context.Background()
1086+
config := wasiPluginConfig()
1087+
1088+
compiledPlugin, err := NewCompiledPlugin(ctx, manifest, PluginConfig{
1089+
EnableWasi: true,
1090+
}, []HostFunction{})
1091+
1092+
if err != nil {
1093+
t.Fatalf("Could not create plugin: %v", err)
1094+
}
1095+
1096+
for i := 0; i < 3; i++ {
1097+
plugin, err := compiledPlugin.Instance(ctx, config)
1098+
if err != nil {
1099+
t.Fatalf("Could not create plugin instance: %v", err)
1100+
}
1101+
// purposefully not closing the plugin instance
1102+
1103+
for j := 0; j < 3; j++ {
1104+
1105+
exit, output, err := plugin.Call("run_test", []byte("benjamin"))
1106+
1107+
if assertCall(t, err, exit) {
1108+
expected := "Hello, BENJAMIN"
1109+
1110+
actual := string(output)
1111+
1112+
assert.Equal(t, expected, actual)
1113+
}
1114+
}
1115+
}
1116+
}
1117+
1118+
func TestCompiledModuleMultipleInstances(t *testing.T) {
1119+
manifest := Manifest{
1120+
Wasm: []Wasm{
1121+
WasmFile{
1122+
Path: "wasm/count_vowels.wasm",
1123+
Name: "main",
1124+
},
1125+
},
1126+
}
1127+
1128+
ctx := context.Background()
1129+
config := wasiPluginConfig()
1130+
1131+
compiledPlugin, err := NewCompiledPlugin(ctx, manifest, PluginConfig{
1132+
EnableWasi: true,
1133+
}, []HostFunction{})
1134+
1135+
if err != nil {
1136+
t.Fatalf("Could not create plugin: %v", err)
1137+
}
1138+
1139+
var wg sync.WaitGroup
1140+
numInstances := 300
1141+
1142+
// Create and test instances in parallel
1143+
for i := 0; i < numInstances; i++ {
1144+
wg.Add(1)
1145+
go func(instanceNum int) {
1146+
defer wg.Done()
1147+
1148+
plugin, err := compiledPlugin.Instance(ctx, config)
1149+
if err != nil {
1150+
t.Errorf("Could not create plugin instance %d: %v", instanceNum, err)
1151+
return
1152+
}
1153+
// purposefully not closing the plugin instance
1154+
1155+
// Sequential calls for this instance
1156+
for j := 0; j < 3; j++ {
1157+
exit, _, err := plugin.Call("count_vowels", []byte("benjamin"))
1158+
if err != nil {
1159+
t.Errorf("Instance %d, call %d failed: %v", instanceNum, j, err)
1160+
return
1161+
}
1162+
if exit != 0 {
1163+
t.Errorf("Instance %d, call %d returned non-zero exit code: %d", instanceNum, j, exit)
1164+
}
1165+
}
1166+
}(i)
1167+
}
1168+
wg.Wait()
1169+
}
1170+
1171+
func TestMultipleCallsOutputParallel(t *testing.T) {
1172+
manifest := manifest("count_vowels.wasm")
1173+
numInstances := 300
1174+
1175+
var wg sync.WaitGroup
1176+
1177+
// Create and test instances in parallel
1178+
for i := 0; i < numInstances; i++ {
1179+
wg.Add(1)
1180+
go func(instanceNum int) {
1181+
defer wg.Done()
1182+
1183+
if plugin, ok := pluginInstance(t, manifest); ok {
1184+
defer plugin.Close(context.Background())
1185+
1186+
// Sequential calls for this instance
1187+
exit, output1, err := plugin.Call("count_vowels", []byte("aaa"))
1188+
if !assertCall(t, err, exit) {
1189+
return
1190+
}
1191+
1192+
exit, output2, err := plugin.Call("count_vowels", []byte("bbba"))
1193+
if !assertCall(t, err, exit) {
1194+
return
1195+
}
1196+
1197+
assert.Equal(t, `{"count":3,"total":3,"vowels":"aeiouAEIOU"}`, string(output1))
1198+
assert.Equal(t, `{"count":1,"total":4,"vowels":"aeiouAEIOU"}`, string(output2))
1199+
}
1200+
}(i)
1201+
}
1202+
1203+
wg.Wait()
1204+
}
1205+
10411206
func BenchmarkInitialize(b *testing.B) {
10421207
ctx := context.Background()
10431208
cache := wazero.NewCompilationCache()

0 commit comments

Comments
 (0)