Skip to content

Commit 91824b3

Browse files
authored
added SessionProvider (#8)
* added SessionProvider * added SessionProviderFun docs and example
1 parent dce4275 commit 91824b3

6 files changed

+105
-3
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
.vscode

config.go

+9-1
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,15 @@ import (
1010
"github.com/aws/aws-sdk-go/aws/session"
1111
)
1212

13+
// SessionProviderFunc can be used to add custom AWS session setup to the gosqs.Config.
14+
// Callers simply need to implement this function type and set it as Config.SessionProvider.
15+
// If Config.SessionProvider is not set (is nil), a default provider based on AWS Key/Secret will be used.
16+
type SessionProviderFunc func(c Config) (*session.Session, error)
17+
1318
// Config defines the gosqs configuration
1419
type Config struct {
20+
// a way to provide custom session setup. A default based on key/secret will be used if not provided
21+
SessionProvider SessionProviderFunc
1522
// private key to access aws
1623
Key string
1724
// secret to access aws
@@ -108,7 +115,8 @@ func (r retryer) MaxRetries() int {
108115
return 10
109116
}
110117

111-
// newSession creates a new aws session
118+
// newSession creates a new aws session.
119+
// This will be used as the default SessionProvider if one is not set
112120
func newSession(c Config) (*session.Session, error) {
113121
//sets credentials
114122
creds := credentials.NewStaticCredentials(c.Key, c.Secret, "")

consumer.go

+6-1
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,12 @@ type consumer struct {
5757
// NewConsumer creates a new SQS instance and provides a configured consumer interface for
5858
// receiving and sending messages
5959
func NewConsumer(c Config, queueName string) (Consumer, error) {
60-
sess, err := newSession(c)
60+
if c.SessionProvider == nil {
61+
c.SessionProvider = newSession
62+
}
63+
64+
sess, err := c.SessionProvider(c)
65+
6166
if err != nil {
6267
return nil, err
6368
}

consumer_test.go

+37
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,10 @@ import (
55
"testing"
66
"time"
77

8+
"github.com/aws/aws-sdk-go/aws"
9+
"github.com/aws/aws-sdk-go/aws/credentials"
10+
"github.com/aws/aws-sdk-go/aws/request"
11+
"github.com/aws/aws-sdk-go/aws/session"
812
"github.com/aws/aws-sdk-go/service/sqs"
913
)
1014

@@ -84,6 +88,39 @@ func TestNewConsumer(t *testing.T) {
8488
}
8589
}
8690

91+
func TestNewConsumerWithSessionProvider(t *testing.T) {
92+
provider := func(c Config) (*session.Session, error) {
93+
creds := credentials.NewStaticCredentials("mykey", "mysecret", "")
94+
_, err := creds.Get()
95+
if err != nil {
96+
return nil, ErrInvalidCreds.Context(err)
97+
}
98+
99+
r := &retryer{retryCount: c.RetryCount}
100+
101+
cfg := request.WithRetryer(aws.NewConfig().WithRegion("us-west2").WithCredentials(creds), r)
102+
103+
hostname := "http://localhost:4100"
104+
cfg.Endpoint = &hostname
105+
106+
return session.NewSession(cfg)
107+
}
108+
109+
conf := Config{
110+
SessionProvider: provider,
111+
Env: "dev",
112+
}
113+
114+
c, err := NewConsumer(conf, "post-worker")
115+
if err != nil {
116+
t.Fatalf("error creating consumer, got %v", err)
117+
}
118+
expected := "http://local.goaws:4100/queue/dev-post-worker"
119+
if c.(*consumer).QueueURL != expected {
120+
t.Fatalf("did not properly apply http result, expected %s, got %s", expected, c.(*consumer).QueueURL)
121+
}
122+
}
123+
87124
func TestRegisterHandler(t *testing.T) {
88125
c := getConsumer(t)
89126
a := []Adapter{}

examples/session_provider.go

+46
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
package example
2+
3+
import (
4+
"github.com/aws/aws-sdk-go/aws"
5+
"github.com/aws/aws-sdk-go/aws/credentials"
6+
"github.com/aws/aws-sdk-go/aws/session"
7+
"github.com/qhenkart/gosqs"
8+
)
9+
10+
func main_with_session_provider() {
11+
12+
// implement a custom AWS session provider function
13+
provider := func(c gosqs.Config) (*session.Session, error) {
14+
15+
// note: this implementation just hardcodes key and secret, but it could do anything
16+
creds := credentials.NewStaticCredentials("mykey", "mysecret", "")
17+
_, err := creds.Get()
18+
if err != nil {
19+
return nil, gosqs.ErrInvalidCreds.Context(err)
20+
}
21+
22+
cfg := aws.NewConfig().WithRegion("us-west-1").WithCredentials(creds)
23+
24+
hostname := "http://localhost:4150"
25+
cfg.Endpoint = &hostname
26+
27+
return session.NewSession(cfg)
28+
}
29+
30+
// create the gosqs Config with our custom SessionProviderFunc
31+
c := gosqs.Config{
32+
// for emulation only
33+
// Hostname: "http://localhost:4150",
34+
35+
SessionProvider: provider,
36+
TopicARN: "arn:aws:sns:local:000000000000:dispatcher",
37+
Region: "us-west-1",
38+
}
39+
40+
//follows the flow to see how a worker should be configured and operate
41+
initWorker(c)
42+
43+
//follows the flow to see how an http service should be configured and operate
44+
initService(c)
45+
46+
}

publisher.go

+6-1
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,12 @@ type publisher struct {
5555

5656
// NewPublisher creates a new SQS/SNS publisher instance
5757
func NewPublisher(c Config) (Publisher, error) {
58-
sess, err := newSession(c)
58+
if c.SessionProvider == nil {
59+
c.SessionProvider = newSession
60+
}
61+
62+
sess, err := c.SessionProvider(c)
63+
5964
if err != nil {
6065
return nil, err
6166
}

0 commit comments

Comments
 (0)