Skip to content

Commit 16d05ac

Browse files
committed
Fix GetSchema
It looks like `GetSchema` was abandon, there weren't any tests for it and it was completely broken with `infer`. See pulumiverse/pulumi-talos#195 for an example: panic: interface conversion: interface {} is nil, not provider.RunInfo goroutine 1 [running]: github.com/pulumi/pulumi-go-provider.GetRunInfo(...) /home/runner/go/pkg/mod/github.com/pulumi/pulumi-go-provider@v1.1.2/provider.go:607 github.com/pulumi/pulumi-go-provider/middleware/schema.(*state).generateSchema(_, {_, _}) /home/runner/go/pkg/mod/github.com/pulumi/pulumi-go-provider@v1.1.2/middleware/schema/schema.go:289 +0x9f0 github.com/pulumi/pulumi-go-provider/middleware/schema.(*state).GetSchema(0xc000ddaea0, {0x494e508, 0xc0012ee370}, {0xc0008e7c70?}) /home/runner/go/pkg/mod/github.com/pulumi/pulumi-go-provider@v1.1.2/middleware/schema/schema.go:191 +0x11b github.com/pulumi/pulumi-go-provider/middleware/cancel.Wrap.setCancel2[...].func7({0x0?}) /home/runner/go/pkg/mod/github.com/pulumi/pulumi-go-provider@v1.1.2/middleware/cancel/cancel.go:134 +0xa2 github.com/pulumi/pulumi-go-provider.GetSchema({_, _}, {_, _}, {_, _}, {0xc000ddc810, 0x0, 0xc000dbc240, 0xc000ddc840, ...}) /home/runner/go/pkg/mod/github.com/pulumi/pulumi-go-provider@v1.1.2/provider.go:557 +0x191 github.com/pulumiverse/pulumi-talos/provider/native.(*Provider).GetSpec(_, {_, _}, {_, _}, {_, _}) /home/runner/work/pulumi-talos/pulumi-talos/provider/native/provider.go:36 +0xf6 github.com/pulumi/pulumi-terraform-bridge/v3/pkg/tfgen.genPulumiSchema(_, {_, _}, {_, _}, {{0x49a43c0, 0xc000da68c0}, {0x41d1110, 0x5}, {0x0, ...}, ...}, ...) /home/runner/go/pkg/mod/github.com/pulumi/pulumi-terraform-bridge/v3@v3.116.0/pkg/tfgen/generate_schema.go:251 +0x347
1 parent 7f6c96d commit 16d05ac

File tree

2 files changed

+136
-40
lines changed

2 files changed

+136
-40
lines changed

provider.go

Lines changed: 33 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,6 @@ import (
2323
"encoding/json"
2424
"errors"
2525
"fmt"
26-
"io"
27-
"os"
2826
"time"
2927

3028
"github.com/blang/semver"
@@ -531,40 +529,26 @@ func RawServer(
531529
return newProvider(name, version, provider.WithDefaults())
532530
}
533531

534-
// A context which prints its diagnostics, collecting all errors.
535-
type errCollectingContext struct {
536-
context.Context
537-
errs multierror.Error
538-
info RunInfo
539-
stderr io.Writer
532+
var _ logSink = &errorPreservingLogger{}
533+
534+
// A logger which collects errors.
535+
type errorPreservingLogger struct {
536+
errs multierror.Error
537+
sink logSink
540538
}
541539

542-
func (e *errCollectingContext) Log(severity diag.Severity, msg string) {
543-
if severity == diag.Error {
540+
func (e *errorPreservingLogger) Log(ctx context.Context, urn presource.URN, sev diag.Severity, msg string) {
541+
if sev == diag.Error {
544542
e.errs.Errors = append(e.errs.Errors, errors.New(msg))
545543
}
546-
_, err := fmt.Fprintf(e.stderr, "Log(%s): %s\n", severity, msg)
547-
contract.IgnoreError(err)
544+
e.sink.Log(ctx, urn, sev, msg)
548545
}
549546

550-
func (e *errCollectingContext) Logf(severity diag.Severity, msg string, args ...any) {
551-
e.Log(severity, fmt.Sprintf(msg, args...))
552-
}
553-
554-
func (e *errCollectingContext) LogStatus(severity diag.Severity, msg string) {
555-
if severity == diag.Error {
547+
func (e *errorPreservingLogger) LogStatus(ctx context.Context, urn presource.URN, sev diag.Severity, msg string) {
548+
if sev == diag.Error {
556549
e.errs.Errors = append(e.errs.Errors, errors.New(msg))
557550
}
558-
_, err := fmt.Fprintf(e.stderr, "LogStatus(%s): %s\n", severity, msg)
559-
contract.IgnoreError(err)
560-
}
561-
562-
func (e *errCollectingContext) LogStatusf(severity diag.Severity, msg string, args ...any) {
563-
e.LogStatus(severity, fmt.Sprintf(msg, args...))
564-
}
565-
566-
func (e *errCollectingContext) RuntimeInformation() RunInfo {
567-
return e.info
551+
e.sink.LogStatus(ctx, urn, sev, msg)
568552
}
569553

570554
// GetSchema retrieves the schema from the provider by invoking GetSchema on the provider.
@@ -576,22 +560,31 @@ func (e *errCollectingContext) RuntimeInformation() RunInfo {
576560
//
577561
// pulumi package get-schema ./pulumi-resource-MYPROVIDER
578562
func GetSchema(ctx context.Context, name, version string, provider Provider) (schema.PackageSpec, error) {
579-
collectingDiag := errCollectingContext{Context: ctx, stderr: os.Stderr, info: RunInfo{
580-
PackageName: name,
581-
Version: version,
582-
}}
583-
s, err := provider.GetSchema(&collectingDiag, GetSchemaRequest{Version: 0})
584-
var errs multierror.Error
563+
564+
// Wrap GetSchema with a special logger that will extract errors.
565+
getSchema := provider.GetSchema
566+
provider.GetSchema = func(ctx context.Context, req GetSchemaRequest) (GetSchemaResponse, error) {
567+
collectingDiag := &errorPreservingLogger{sink: GetLogger(ctx).inner}
568+
ctx = context.WithValue(ctx, key.Logger, collectingDiag)
569+
resp, err := getSchema(ctx, req)
570+
if err != nil {
571+
collectingDiag.errs.Errors = append(collectingDiag.errs.Errors, fmt.Errorf("GetSchema failed: %w", err))
572+
}
573+
return resp, collectingDiag.errs.ErrorOrNil()
574+
}
575+
576+
p, err := RawServer(name, version, provider)(nil)
585577
if err != nil {
586-
errs.Errors = append(errs.Errors, err)
578+
return schema.PackageSpec{}, fmt.Errorf("constructing provider: %w", err)
587579
}
588-
errs.Errors = append(errs.Errors, collectingDiag.errs.Errors...)
589580

590-
spec := schema.PackageSpec{}
591-
if err := errs.ErrorOrNil(); err != nil {
592-
return spec, err
581+
response, err := p.GetSchema(ctx, &rpc.GetSchemaRequest{})
582+
if err != nil {
583+
return schema.PackageSpec{}, err
593584
}
594-
err = json.Unmarshal([]byte(s.Schema), &spec)
585+
586+
spec := schema.PackageSpec{}
587+
err = json.Unmarshal([]byte(response.Schema), &spec)
595588
return spec, err
596589
}
597590

provider_test.go

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,11 @@
1515
package provider
1616

1717
import (
18+
"context"
1819
"testing"
1920

2021
"github.com/stretchr/testify/assert"
22+
"github.com/stretchr/testify/require"
2123
)
2224

2325
func TestDiffResponseRPC(t *testing.T) {
@@ -107,3 +109,104 @@ func TestDiffResponseRPC(t *testing.T) {
107109
assert.Nil(t, rpcResp.Stables)
108110
})
109111
}
112+
113+
func TestGetSchema(t *testing.T) {
114+
t.Parallel()
115+
116+
t.Run("logged errors are included in returned error", func(t *testing.T) {
117+
t.Parallel()
118+
provider := Provider{
119+
GetSchema: func(ctx context.Context, req GetSchemaRequest) (GetSchemaResponse, error) {
120+
logger := GetLogger(ctx)
121+
logger.Error("first error")
122+
logger.Error("second error")
123+
return GetSchemaResponse{Schema: `{"name":"test"}`}, nil
124+
},
125+
}
126+
127+
_, err := GetSchema(t.Context(), "test", "1.0.0", provider)
128+
require.Error(t, err)
129+
assert.Contains(t, err.Error(), "first error")
130+
assert.Contains(t, err.Error(), "second error")
131+
})
132+
133+
t.Run("GetSchema function error is included in returned error", func(t *testing.T) {
134+
t.Parallel()
135+
provider := Provider{
136+
GetSchema: func(ctx context.Context, req GetSchemaRequest) (GetSchemaResponse, error) {
137+
return GetSchemaResponse{}, assert.AnError
138+
},
139+
}
140+
141+
_, err := GetSchema(t.Context(), "test", "1.0.0", provider)
142+
require.Error(t, err)
143+
assert.Contains(t, err.Error(), "GetSchema failed")
144+
assert.ErrorIs(t, err, assert.AnError)
145+
})
146+
147+
t.Run("both logged errors and function error are included", func(t *testing.T) {
148+
t.Parallel()
149+
provider := Provider{
150+
GetSchema: func(ctx context.Context, req GetSchemaRequest) (GetSchemaResponse, error) {
151+
logger := GetLogger(ctx)
152+
logger.Error("logged error")
153+
return GetSchemaResponse{}, assert.AnError
154+
},
155+
}
156+
157+
_, err := GetSchema(t.Context(), "test", "1.0.0", provider)
158+
require.Error(t, err)
159+
assert.Contains(t, err.Error(), "logged error")
160+
assert.Contains(t, err.Error(), "GetSchema failed")
161+
assert.ErrorIs(t, err, assert.AnError)
162+
})
163+
164+
t.Run("GetRunInfo is accessible in GetSchema", func(t *testing.T) {
165+
t.Parallel()
166+
var capturedRunInfo RunInfo
167+
provider := Provider{
168+
GetSchema: func(ctx context.Context, req GetSchemaRequest) (GetSchemaResponse, error) {
169+
capturedRunInfo = GetRunInfo(ctx)
170+
return GetSchemaResponse{Schema: `{"name":"test"}`}, nil
171+
},
172+
}
173+
174+
_, err := GetSchema(t.Context(), "test-package", "2.3.4", provider)
175+
require.NoError(t, err)
176+
assert.Equal(t, RunInfo{
177+
PackageName: "test-package",
178+
Version: "2.3.4",
179+
SupportsOldInputs: false,
180+
}, capturedRunInfo)
181+
})
182+
183+
t.Run("non-error logs are not included in error", func(t *testing.T) {
184+
t.Parallel()
185+
provider := Provider{
186+
GetSchema: func(ctx context.Context, req GetSchemaRequest) (GetSchemaResponse, error) {
187+
logger := GetLogger(ctx)
188+
logger.Info("info message")
189+
logger.Warning("warning message")
190+
logger.Debug("debug message")
191+
return GetSchemaResponse{Schema: `{"name":"test"}`}, nil
192+
},
193+
}
194+
195+
_, err := GetSchema(t.Context(), "test", "1.0.0", provider)
196+
require.NoError(t, err)
197+
})
198+
199+
t.Run("success with valid schema", func(t *testing.T) {
200+
t.Parallel()
201+
provider := Provider{
202+
GetSchema: func(ctx context.Context, req GetSchemaRequest) (GetSchemaResponse, error) {
203+
return GetSchemaResponse{Schema: `{"name":"mypackage","version":"1.0.0"}`}, nil
204+
},
205+
}
206+
207+
spec, err := GetSchema(t.Context(), "test", "1.0.0", provider)
208+
require.NoError(t, err)
209+
assert.Equal(t, "mypackage", spec.Name)
210+
assert.Equal(t, "1.0.0", spec.Version)
211+
})
212+
}

0 commit comments

Comments
 (0)