diff --git a/encoding/protobuf/codec_registry.go b/encoding/protobuf/codec_registry.go new file mode 100644 index 000000000..f803f721c --- /dev/null +++ b/encoding/protobuf/codec_registry.go @@ -0,0 +1,169 @@ +// Copyright (c) 2025 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package protobuf + +import ( + "bytes" + "sync" + + "github.com/gogo/protobuf/jsonpb" + "github.com/gogo/protobuf/proto" + "go.uber.org/yarpc/api/transport" + "go.uber.org/yarpc/internal/bufferpool" + "google.golang.org/grpc/encoding" + "google.golang.org/grpc/mem" +) + +// Global codec registry for encoding-based codec overrides. +// Uses gRPC's encoding.CodecV2 interface for full codec v2 compatibility. +var ( + codecRegistryMutex sync.RWMutex + codecRegistry = make(map[string]encoding.CodecV2) +) + +func init() { + // Register built-in codecs at package initialization + RegisterCodec(&protoCodec{}) + RegisterCodec(&jsonCodec{codec: newDefaultCodec()}) +} + +// RegisterCodec registers a codec for gRPC transport. +// Accepts any gRPC encoding.CodecV2 implementation. +func RegisterCodec(codec encoding.CodecV2) { + codecRegistryMutex.Lock() + defer codecRegistryMutex.Unlock() + codecRegistry[codec.Name()] = codec +} + +// getCodecForEncoding returns the registered codec for an encoding. +// For JSON encoding, uses the custom codec's marshalers if provided. +func getCodecForEncoding(encoding transport.Encoding, c *codec) encoding.CodecV2 { + // If encoding is JSON and we have a custom codec, use it + if encoding == JSONEncoding && c != nil { + return &jsonCodec{codec: c} + } + + codecRegistryMutex.RLock() + defer codecRegistryMutex.RUnlock() + + return codecRegistry[string(encoding)] +} + +// GetCodecForEncoding is the public version of getCodecForEncoding for testing/examples. +// It returns the codec registered in the registry without custom codec context. +func GetCodecForEncoding(encoding transport.Encoding) encoding.CodecV2 { + return getCodecForEncoding(encoding, nil) +} + +// GetCodecNames returns the names of all registered codecs +func GetCodecNames() []transport.Encoding { + codecRegistryMutex.RLock() + defer codecRegistryMutex.RUnlock() + + names := make([]transport.Encoding, 0, len(codecRegistry)) + for encoding := range codecRegistry { + names = append(names, transport.Encoding(encoding)) + } + return names +} + +// protoCodec implements encoding.CodecV2 for protobuf encoding +type protoCodec struct{} + +func (c *protoCodec) Marshal(v any) (mem.BufferSlice, error) { + message, ok := v.(proto.Message) + if !ok { + return nil, proto.NewRequiredNotSetError("message is not a proto.Message") + } + + data, err := proto.Marshal(message) + if err != nil { + return nil, err + } + + return mem.BufferSlice{mem.SliceBuffer(data)}, nil +} + +func (c *protoCodec) Unmarshal(data mem.BufferSlice, v any) error { + message, ok := v.(proto.Message) + if !ok { + return proto.NewRequiredNotSetError("message is not a proto.Message") + } + return proto.Unmarshal(data.Materialize(), message) +} + +func (c *protoCodec) Name() string { + return string(Encoding) +} + +// jsonCodec implements encoding.CodecV2 for JSON encoding +type jsonCodec struct { + codec *codec +} + +func (c *jsonCodec) Marshal(v any) (mem.BufferSlice, error) { + message, ok := v.(proto.Message) + if !ok { + return nil, proto.NewRequiredNotSetError("message is not a proto.Message") + } + buf := bufferpool.Get() + if err := c.codec.jsonMarshaler.Marshal(buf, message); err != nil { + bufferpool.Put(buf) + return nil, err + } + data := append([]byte(nil), buf.Bytes()...) + bufferpool.Put(buf) + return mem.BufferSlice{mem.SliceBuffer(data)}, nil +} + +func (c *jsonCodec) Unmarshal(data mem.BufferSlice, v any) error { + message, ok := v.(proto.Message) + if !ok { + return proto.NewRequiredNotSetError("message is not a proto.Message") + } + return c.codec.jsonUnmarshaler.Unmarshal(bytes.NewReader(data.Materialize()), message) +} + +func (c *jsonCodec) Name() string { + return string(JSONEncoding) +} + +// codec is a private helper struct used to hold custom marshaling behavior for JSON. +type codec struct { + jsonMarshaler *jsonpb.Marshaler + jsonUnmarshaler *jsonpb.Unmarshaler +} + +// newDefaultCodec creates the default codec used for built-in JSON encoding. +func newDefaultCodec() *codec { + return &codec{ + jsonMarshaler: &jsonpb.Marshaler{}, + jsonUnmarshaler: &jsonpb.Unmarshaler{AllowUnknownFields: true}, + } +} + +// newCodec creates a codec with a custom AnyResolver for JSON marshaling. +func newCodec(anyResolver jsonpb.AnyResolver) *codec { + return &codec{ + jsonMarshaler: &jsonpb.Marshaler{AnyResolver: anyResolver}, + jsonUnmarshaler: &jsonpb.Unmarshaler{AnyResolver: anyResolver, AllowUnknownFields: true}, + } +} diff --git a/encoding/protobuf/codec_registry_test.go b/encoding/protobuf/codec_registry_test.go new file mode 100644 index 000000000..653c0f0a2 --- /dev/null +++ b/encoding/protobuf/codec_registry_test.go @@ -0,0 +1,135 @@ +// Copyright (c) 2025 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package protobuf + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + "google.golang.org/grpc/mem" +) + +func TestCodecRegistry(t *testing.T) { + // Test codec registration and retrieval + t.Run("codecRegistration", func(t *testing.T) { + // Create a mock codec + mockCodec := &MockYARPCCodec{name: "test-codec"} + + // Register codec + RegisterCodec(mockCodec) + + // Test retrieval by codec name + retrieved := getCodecForEncoding("test-codec", nil) + assert.Equal(t, mockCodec, retrieved, "Should return registered codec") + + // Test fallback for unknown encoding + unknown := getCodecForEncoding("unknown-encoding", nil) + assert.Nil(t, unknown, "Should return nil for unknown encoding") + }) + + // Test public API + t.Run("publicAPI", func(t *testing.T) { + mockCodec := &MockYARPCCodec{name: "public-test"} + + // Register codec + RegisterCodec(mockCodec) + + // Test public GetCodecForEncoding function + retrieved := GetCodecForEncoding("public-test") + assert.Equal(t, mockCodec, retrieved, "Public API should return registered codec") + }) + + // Test thread safety (basic check) + t.Run("concurrentAccess", func(t *testing.T) { + codec1 := &MockYARPCCodec{name: "concurrent-1"} + codec2 := &MockYARPCCodec{name: "concurrent-2"} + + // Register from multiple goroutines + done := make(chan bool, 2) + + go func() { + RegisterCodec(codec1) + done <- true + }() + + go func() { + RegisterCodec(codec2) + done <- true + }() + + // Wait for both registrations + <-done + <-done + + // Verify both are registered correctly + assert.Equal(t, codec1, getCodecForEncoding("concurrent-1", nil)) + assert.Equal(t, codec2, getCodecForEncoding("concurrent-2", nil)) + }) + + // Test codec interface compliance + t.Run("codecInterface", func(t *testing.T) { + mockCodec := &MockYARPCCodec{name: "interface-test"} + + // Test Marshal + data, err := mockCodec.Marshal([]byte("test-data")) + assert.NoError(t, err) + assert.NotNil(t, data) + + // Test Unmarshal + var result []byte + bufSlice := mem.BufferSlice{mem.SliceBuffer([]byte("unmarshal-test"))} + err = mockCodec.Unmarshal(bufSlice, &result) + assert.NoError(t, err) + assert.Equal(t, []byte("unmarshal-test"), result) + + // Test Name + assert.Equal(t, "interface-test", mockCodec.Name()) + }) +} + +// MockYARPCCodec implements encoding.CodecV2 for testing +type MockYARPCCodec struct { + name string +} + +func (m *MockYARPCCodec) Marshal(v any) (mem.BufferSlice, error) { + switch value := v.(type) { + case []byte: + return mem.BufferSlice{mem.SliceBuffer(value)}, nil + default: + return nil, fmt.Errorf("expected []byte but got %T", v) + } +} + +func (m *MockYARPCCodec) Unmarshal(data mem.BufferSlice, v any) error { + switch value := v.(type) { + case *[]byte: + *value = data.Materialize() + return nil + default: + return fmt.Errorf("expected *[]byte but got %T", v) + } +} + +func (m *MockYARPCCodec) Name() string { + return m.name +} diff --git a/encoding/protobuf/error.go b/encoding/protobuf/error.go index 465d599cc..342e2264b 100644 --- a/encoding/protobuf/error.go +++ b/encoding/protobuf/error.go @@ -159,11 +159,12 @@ func createStatusWithDetail(pberr *pberror, encoding transport.Encoding, codec * pst := st.Proto() pst.Details = pberr.details - detailsBytes, cleanup, marshalErr := marshal(encoding, pst, codec) + detailsBufferSlice, marshalErr := marshal(encoding, pst, codec) if marshalErr != nil { return nil, marshalErr } - defer cleanup() + // Materialize and copy for YARPC error details + detailsBytes := detailsBufferSlice.Materialize() yarpcDet := make([]byte, len(detailsBytes)) copy(yarpcDet, detailsBytes) return yarpcerrors.Newf(pberr.code, pberr.message).WithDetails(yarpcDet), nil diff --git a/encoding/protobuf/inbound.go b/encoding/protobuf/inbound.go index a581d9599..8c61751e9 100644 --- a/encoding/protobuf/inbound.go +++ b/encoding/protobuf/inbound.go @@ -58,20 +58,16 @@ func (u *unaryHandler) Handle(ctx context.Context, transportRequest *transport.R if err := call.WriteToResponse(responseWriter); err != nil { return err } - var responseData []byte - var responseCleanup func() if response != nil { - responseData, responseCleanup, err = marshal(transportRequest.Encoding, response, u.codec) - if responseCleanup != nil { - defer responseCleanup() - } + responseData, err := marshal(transportRequest.Encoding, response, u.codec) if err != nil { return errors.ResponseBodyEncodeError(transportRequest, err) } - } - _, err = responseWriter.Write(responseData) - if err != nil { - return err + // Materialize BufferSlice to []byte for ResponseWriter.Write() + _, err = responseWriter.Write(responseData.Materialize()) + if err != nil { + return err + } } if appErr != nil { responseWriter.SetApplicationError() @@ -129,7 +125,7 @@ func (s *streamHandler) HandleStream(stream *transport.ServerStream) error { } func getProtoRequest(ctx context.Context, transportRequest *transport.Request, newRequest func() proto.Message, codec *codec) (context.Context, *apiencoding.InboundCall, proto.Message, error) { - if err := errors.ExpectEncodings(transportRequest, Encoding, JSONEncoding); err != nil { + if err := errors.ExpectEncodings(transportRequest, GetCodecNames()...); err != nil { return nil, nil, nil, err } ctx, call := apiencoding.NewInboundCall(ctx) diff --git a/encoding/protobuf/marshal.go b/encoding/protobuf/marshal.go index f994b03dd..213bd44dd 100644 --- a/encoding/protobuf/marshal.go +++ b/encoding/protobuf/marshal.go @@ -21,38 +21,15 @@ package protobuf import ( - "bytes" "io" - "sync" - "github.com/gogo/protobuf/jsonpb" "github.com/gogo/protobuf/proto" "go.uber.org/yarpc/api/transport" "go.uber.org/yarpc/internal/bufferpool" "go.uber.org/yarpc/yarpcerrors" + "google.golang.org/grpc/mem" ) -var ( - _bufferPool = sync.Pool{ - New: func() interface{} { - return proto.NewBuffer(make([]byte, 1024)) - }, - } -) - -// codec is a private helper struct used to hold custom marshling behavior. -type codec struct { - jsonMarshaler *jsonpb.Marshaler - jsonUnmarshaler *jsonpb.Unmarshaler -} - -func newCodec(anyResolver jsonpb.AnyResolver) *codec { - return &codec{ - jsonMarshaler: &jsonpb.Marshaler{AnyResolver: anyResolver}, - jsonUnmarshaler: &jsonpb.Unmarshaler{AnyResolver: anyResolver, AllowUnknownFields: true}, - } -} - func unmarshal(encoding transport.Encoding, reader io.Reader, message proto.Message, codec *codec) error { buf := bufferpool.Get() defer bufferpool.Put(buf) @@ -67,61 +44,20 @@ func unmarshal(encoding transport.Encoding, reader io.Reader, message proto.Mess } func unmarshalBytes(encoding transport.Encoding, body []byte, message proto.Message, codec *codec) error { - switch encoding { - case Encoding: - return unmarshalProto(body, message, codec) - case JSONEncoding: - return unmarshalJSON(body, message, codec) - default: + customCodec := getCodecForEncoding(encoding, codec) + if customCodec == nil { return yarpcerrors.Newf(yarpcerrors.CodeInternal, "encoding.Expect should have handled encoding %q but did not", encoding) } -} - -func unmarshalProto(body []byte, message proto.Message, _ *codec) error { - return proto.Unmarshal(body, message) -} -func unmarshalJSON(body []byte, message proto.Message, codec *codec) error { - return codec.jsonUnmarshaler.Unmarshal(bytes.NewReader(body), message) + bufferSlice := mem.BufferSlice{mem.SliceBuffer(body)} + return customCodec.Unmarshal(bufferSlice, message) } -func marshal(encoding transport.Encoding, message proto.Message, codec *codec) ([]byte, func(), error) { - switch encoding { - case Encoding: - return marshalProto(message, codec) - case JSONEncoding: - return marshalJSON(message, codec) - default: - return nil, nil, yarpcerrors.Newf(yarpcerrors.CodeInternal, "encoding.Expect should have handled encoding %q but did not", encoding) +func marshal(encoding transport.Encoding, message proto.Message, codec *codec) (mem.BufferSlice, error) { + customCodec := getCodecForEncoding(encoding, codec) + if customCodec == nil { + return nil, yarpcerrors.Newf(yarpcerrors.CodeInternal, "encoding.Expect should have handled encoding %q but did not", encoding) } -} - -func marshalProto(message proto.Message, _ *codec) ([]byte, func(), error) { - protoBuffer := getBuffer() - cleanup := func() { putBuffer(protoBuffer) } - if err := protoBuffer.Marshal(message); err != nil { - cleanup() - return nil, nil, err - } - return protoBuffer.Bytes(), cleanup, nil -} - -func marshalJSON(message proto.Message, codec *codec) ([]byte, func(), error) { - buf := bufferpool.Get() - cleanup := func() { bufferpool.Put(buf) } - if err := codec.jsonMarshaler.Marshal(buf, message); err != nil { - cleanup() - return nil, nil, err - } - return buf.Bytes(), cleanup, nil -} - -func getBuffer() *proto.Buffer { - buf := _bufferPool.Get().(*proto.Buffer) - buf.Reset() - return buf -} -func putBuffer(buf *proto.Buffer) { - _bufferPool.Put(buf) + return customCodec.Marshal(message) } diff --git a/encoding/protobuf/marshal_test.go b/encoding/protobuf/marshal_test.go index 538c6ffe1..8d0487d55 100644 --- a/encoding/protobuf/marshal_test.go +++ b/encoding/protobuf/marshal_test.go @@ -32,6 +32,6 @@ import ( func TestUnhandledEncoding(t *testing.T) { assert.Equal(t, yarpcerrors.CodeInternal, yarpcerrors.FromError(unmarshal(transport.Encoding("foo"), strings.NewReader("foo"), nil, newCodec(nil))).Code()) - _, _, err := marshal(transport.Encoding("foo"), nil, newCodec(nil)) + _, err := marshal(transport.Encoding("foo"), nil, newCodec(nil)) assert.Equal(t, yarpcerrors.CodeInternal, yarpcerrors.FromError(err).Code()) } diff --git a/encoding/protobuf/outbound.go b/encoding/protobuf/outbound.go index 00170bb4a..fd26dd372 100644 --- a/encoding/protobuf/outbound.go +++ b/encoding/protobuf/outbound.go @@ -21,7 +21,6 @@ package protobuf import ( - "bytes" "context" "github.com/gogo/protobuf/jsonpb" @@ -151,15 +150,14 @@ func (c *client) buildTransportRequest(ctx context.Context, requestMethodName st return nil, nil, nil, nil, yarpcerrors.Newf(yarpcerrors.CodeInternal, "can only use encodings %q or %q, but %q was specified", Encoding, JSONEncoding, transportRequest.Encoding) } if request != nil { - requestData, cleanup, err := marshal(transportRequest.Encoding, request, c.codec) + requestData, err := marshal(transportRequest.Encoding, request, c.codec) if err != nil { - return nil, nil, nil, cleanup, errors.RequestBodyEncodeError(transportRequest, err) + return nil, nil, nil, nil, errors.RequestBodyEncodeError(transportRequest, err) } - if requestData != nil { - transportRequest.Body = bytes.NewReader(requestData) - transportRequest.BodySize = len(requestData) - } - return ctx, call, transportRequest, cleanup, nil + // Use BufferSlice reader for zero-copy when possible + transportRequest.Body = requestData.Reader() + transportRequest.BodySize = requestData.Len() + return ctx, call, transportRequest, nil, nil } return ctx, call, transportRequest, nil, nil } diff --git a/encoding/protobuf/stream.go b/encoding/protobuf/stream.go index 512c48c57..95431c422 100644 --- a/encoding/protobuf/stream.go +++ b/encoding/protobuf/stream.go @@ -21,7 +21,6 @@ package protobuf import ( - "bytes" "context" "github.com/gogo/protobuf/proto" @@ -52,28 +51,15 @@ func readFromStream( // writeToStream writes a proto.Message to a stream. func writeToStream(ctx context.Context, stream transport.Stream, message proto.Message, codec *codec) error { - messageData, cleanup, err := marshal(stream.Request().Meta.Encoding, message, codec) + messageData, err := marshal(stream.Request().Meta.Encoding, message, codec) if err != nil { return err } return stream.SendMessage( ctx, &transport.StreamMessage{ - Body: readCloser{ - Reader: bytes.NewReader(messageData), - closer: cleanup, - }, - BodySize: len(messageData), + Body: messageData.Reader(), + BodySize: messageData.Len(), }, ) } - -type readCloser struct { - *bytes.Reader - closer func() -} - -func (r readCloser) Close() error { - r.closer() - return nil -}