Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Go: Fix channel passing from Go to Rust by using runtime.Pinner or cgo.Handle #3208

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
* Go: Add `XINFO CONSUMERS` ([#3120](https://github.com/valkey-io/valkey-glide/pull/3120))
* Go: Add `XINFO GROUPS` ([#3106](https://github.com/valkey-io/valkey-glide/pull/3106))
* Go: Add `ZInterCard` ([#3078](https://github.com/valkey-io/valkey-glide/issues/3078))
* Go: Fix channel passing from Go to Rust by using `runtime.Pinner` or `cgo.Handle` ([#3208](https://github.com/valkey-io/valkey-glide/pull/3208))

#### Breaking Changes

Expand Down
16 changes: 10 additions & 6 deletions go/api/base_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,15 +53,15 @@ type payload struct {
//export successCallback
func successCallback(channelPtr unsafe.Pointer, cResponse *C.struct_CommandResponse) {
response := cResponse
resultChannel := *(*chan payload)(channelPtr)
resultChannel := *(*chan payload)(getPinnedPtr(channelPtr))
resultChannel <- payload{value: response, error: nil}
}

//export failureCallback
func failureCallback(channelPtr unsafe.Pointer, cErrorMessage *C.char, cErrorType C.RequestErrorType) {
defer C.free_error_message(cErrorMessage)
msg := C.GoString(cErrorMessage)
resultChannel := *(*chan payload)(channelPtr)
resultChannel := *(*chan payload)(getPinnedPtr(channelPtr))
resultChannel <- payload{value: nil, error: errors.GoError(uint32(cErrorType), msg)}
}

Expand Down Expand Up @@ -209,9 +209,6 @@ func (client *baseClient) executeCommandWithRoute(
argLengthsPtr = &argLengths[0]
}

resultChannel := make(chan payload)
resultChannelPtr := uintptr(unsafe.Pointer(&resultChannel))

var routeBytesPtr *C.uchar = nil
var routeBytesCount C.uintptr_t = 0
if route != nil {
Expand All @@ -228,9 +225,16 @@ func (client *baseClient) executeCommandWithRoute(
routeBytesPtr = (*C.uchar)(C.CBytes(msg))
}

resultChannel := make(chan payload)
resultChannelPtr := unsafe.Pointer(&resultChannel)

pinner := pinner{}
pinnedChannelPtr := pinner.Pin(resultChannelPtr)
defer pinner.Unpin()

C.command(
client.coreClient,
C.uintptr_t(resultChannelPtr),
C.uintptr_t(uintptr(pinnedChannelPtr)),
uint32(requestType),
C.size_t(len(args)),
cArgsPtr,
Expand Down
30 changes: 30 additions & 0 deletions go/api/pinner.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0

//go:build go1.21

package api

import (
"runtime"
"unsafe"
)

// pinner is a wrapper of a runtime.Pinner making the interface
// compatible to the cgo.Handle in the Go < 1.21.
// Note that this make a pinner can only hold one unsafe.Pointer.
type pinner struct {
r runtime.Pinner
}

func (p *pinner) Pin(v unsafe.Pointer) unsafe.Pointer {
p.r.Pin(v)
return v
}

func (p *pinner) Unpin() {
p.r.Unpin()
}

func getPinnedPtr(v unsafe.Pointer) unsafe.Pointer {
return v
}
30 changes: 30 additions & 0 deletions go/api/pinner_old.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0

//go:build !go1.21

package api

import (
"runtime/cgo"
"unsafe"
)

// pinner is a wrapper of a cgo.Handle making the interface
// compatible to the runtime.Pinner in the Go >= 1.21.
// Note that a pinner can only hold one unsafe.Pointer.
type pinner struct {
h cgo.Handle
}

func (p *pinner) Pin(v unsafe.Pointer) unsafe.Pointer {
p.h = cgo.NewHandle(v)
return unsafe.Pointer(&p.h)
}

func (p *pinner) Unpin() {
p.h.Delete()
}

func getPinnedPtr(v unsafe.Pointer) unsafe.Pointer {
return (*(*cgo.Handle)(v)).Value().(unsafe.Pointer)
}
20 changes: 20 additions & 0 deletions go/api/pinner_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
// Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0

package api

import (
"testing"
"unsafe"
)

func TestPinner(t *testing.T) {
v := make(chan payload)

p := pinner{}
n := p.Pin(unsafe.Pointer(&v))
defer p.Unpin()

if *(*chan payload)(getPinnedPtr(n)) != v {
t.Fail()
}
}
43 changes: 43 additions & 0 deletions go/integTest/glide_test_suite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ import (
"os/exec"
"strconv"
"strings"
"sync/atomic"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/suite"
Expand Down Expand Up @@ -227,6 +229,16 @@ func (suite *GlideTestSuite) runWithDefaultClients(test func(client api.BaseClie
suite.runWithClients(clients, test)
}

func (suite *GlideTestSuite) runParallelizedWithDefaultClients(
parallelism int,
count int64,
timeout time.Duration,
test func(client api.BaseClient),
) {
clients := suite.getDefaultClients()
suite.runParallelizedWithClients(clients, parallelism, count, timeout, test)
}

func (suite *GlideTestSuite) getDefaultClients() []api.BaseClient {
return []api.BaseClient{suite.defaultClient(), suite.defaultClusterClient()}
}
Expand Down Expand Up @@ -275,6 +287,37 @@ func (suite *GlideTestSuite) runWithClients(clients []api.BaseClient, test func(
}
}

func (suite *GlideTestSuite) runParallelizedWithClients(
clients []api.BaseClient,
parallelism int,
count int64,
timeout time.Duration,
test func(client api.BaseClient),
) {
for _, client := range clients {
suite.T().Run(fmt.Sprintf("%T", client)[5:], func(t *testing.T) {
done := make(chan struct{}, parallelism)
for i := 0; i < parallelism; i++ {
go func() {
defer func() { done <- struct{}{} }()
for !suite.T().Failed() && atomic.AddInt64(&count, -1) > 0 {
test(client)
}
}()
}
tm := time.NewTimer(timeout)
defer tm.Stop()
for i := 0; i < parallelism; i++ {
select {
case <-done:
case <-tm.C:
suite.T().Fatalf("parallelized test timeout in %s", timeout)
}
}
})
}
}

func (suite *GlideTestSuite) verifyOK(result string, err error) {
assert.Nil(suite.T(), err)
assert.Equal(suite.T(), api.OK, result)
Expand Down
21 changes: 21 additions & 0 deletions go/integTest/parallelized_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0

package integTest

import (
"runtime"
"time"

"github.com/google/uuid"
"github.com/valkey-io/valkey-glide/go/api"
)

func (suite *GlideTestSuite) TestParallelizedSetWithGC() {
// The insane 640 parallelism is required to reproduce https://github.com/valkey-io/valkey-glide/issues/3207.
suite.runParallelizedWithDefaultClients(640, 640000, time.Minute, func(client api.BaseClient) {
runtime.GC()
key := uuid.New().String()
value := uuid.New().String()
suite.verifyOK(client.Set(key, value))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to double check that this test achieve its goal without -race flag.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It also catches the same mistake I made with cgo.Handle.
image

I was doing the following, which basically copied from the cgo.Handle example:

func (p *pinner) Pin(v unsafe.Pointer) unsafe.Pointer {
	p.h = cgo.NewHandle(v)
	return unsafe.Pointer(&p.h)
}

But the &p.h suffers from the same problem as the payload channel. I have updated a corrected version:

func (p *pinner) Pin(v unsafe.Pointer) unsafe.Pointer {
	p.h = cgo.NewHandle(v)
	return unsafe.Pointer(p.h) // Note that unsafe.Pointer(&p.h) is incorrect.
}

Copy link
Author

@rueian rueian Feb 21, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

However, unsafe.Pointer(p.h) is not safe either, since a cgo.Handle is generated by an atomic counter and is not a valid memory address. Converting it to an unsafe.Pointer could lead the Go runtime to do some wrong things.

Therefore, in the latest version, I changed the pinner interface to return an uintptr instead and also updated callback signatures accordingly. f4529b9 Well, go vet is not happy with uintptr->unsafe.Pointer conversion either 🤔. Probably need to find another way to do the conversion.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given that all the pinner functions get inlined, I think it is safe to do unsafe.Pointer(p.h) because it will be cast to uintptr immediately.
image

})
}
Loading