-
Notifications
You must be signed in to change notification settings - Fork 88
Expand file tree
/
Copy pathopen.go
More file actions
280 lines (242 loc) · 8.4 KB
/
open.go
File metadata and controls
280 lines (242 loc) · 8.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
/*
* SPDX-FileCopyrightText: © 2017-2025 Istari Digital, Inc.
* SPDX-License-Identifier: Apache-2.0
*/
package dgo
import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"net/url"
"strconv"
"strings"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"github.com/dgraph-io/dgo/v250/protos/api"
)
// Connection-string vocabulary understood by Open: the required URL scheme,
// the names of the recognised query parameters, and the accepted values of
// the sslmode parameter.
const (
dgraphScheme = "dgraph" // URL scheme that every connection string must use
cloudAPIKeyParam = "apikey" // optional parameter for providing a Dgraph Cloud API key
bearerTokenParam = "bearertoken" // optional parameter for providing an access token
sslModeParam = "sslmode" // optional parameter for providing a Dgraph SSL mode
namespaceParam = "namespace" // optional parameter for providing a Dgraph namespace ID
sslModeDisable = "disable" // no TLS; Open dials with insecure transport credentials (used when sslmode is absent)
sslModeRequire = "require" // TLS without certificate verification (Open applies WithSkipTLSVerify)
sslModeVerifyCA = "verify-ca" // TLS verified against the system cert pool (Open applies WithSystemCertPool)
)
// bearerCreds holds a static access token that is presented on each request
// as a bearer credential.
type bearerCreds struct {
	token string
}

// GetRequestMetadata builds the per-request metadata: a single
// "Authorization" entry whose value is the token prefixed with "Bearer ".
// Neither ctx nor uri is consulted, and the returned error is always nil.
func (c *bearerCreds) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) {
	header := fmt.Sprintf("Bearer %s", c.token)
	md := map[string]string{"Authorization": header}
	return md, nil
}

// RequireTransportSecurity reports whether these credentials demand a secure
// transport. It unconditionally answers true.
func (c *bearerCreds) RequireTransportSecurity() bool {
	const secureOnly = true
	return secureOnly
}
// clientOptions accumulates the settings produced by ClientOption functions
// before a client is constructed.
type clientOptions struct {
// namespace is the namespace ID handed to login together with the ACL credentials.
namespace uint64
// gopts are the gRPC dial options passed to grpc.NewClient for every endpoint.
gopts []grpc.DialOption
// username and password are the ACL credentials; NewRoundRobinClient only
// attempts a login when both are non-empty.
username string
password string
}
// ClientOption is a function that modifies the client options.
// NewRoundRobinClient applies the options in the order given and stops at the
// first one that returns a non-nil error, returning that error to the caller.
type ClientOption func(*clientOptions) error
// WithDgraphAPIKey authenticates against Dgraph Cloud with the given API key.
// When applied, the option wraps the key in per-RPC credentials and appends
// the resulting dial option to the client's gRPC options.
func WithDgraphAPIKey(apiKey string) ClientOption {
	return func(opts *clientOptions) error {
		perRPC := grpc.WithPerRPCCredentials(&authCreds{token: apiKey})
		opts.gopts = append(opts.gopts, perRPC)
		return nil
	}
}
// WithBearerToken authenticates against a Dgraph cluster by presenting the
// given token as a Bearer token in the Authorization header. When applied,
// the option appends per-RPC bearer credentials to the client's gRPC options.
func WithBearerToken(token string) ClientOption {
	return func(opts *clientOptions) error {
		perRPC := grpc.WithPerRPCCredentials(&bearerCreds{token: token})
		opts.gopts = append(opts.gopts, perRPC)
		return nil
	}
}
// WithSkipTLSVerify connects over TLS without verifying the server's
// certificate (the TLS config sets InsecureSkipVerify). When applied, the
// option appends the corresponding transport credentials to the client's
// gRPC options.
func WithSkipTLSVerify() ClientOption {
	return func(opts *clientOptions) error {
		tlsCfg := &tls.Config{InsecureSkipVerify: true}
		transport := grpc.WithTransportCredentials(credentials.NewTLS(tlsCfg))
		opts.gopts = append(opts.gopts, transport)
		return nil
	}
}
// WithSystemCertPool sets up a TLS connection to the Dgraph cluster that is
// verified against the system certificate pool. Applying the option fails if
// the system pool cannot be loaded.
func WithSystemCertPool() ClientOption {
	return func(opts *clientOptions) error {
		systemPool, err := x509.SystemCertPool()
		if err != nil {
			return fmt.Errorf("failed to create system cert pool: %w", err)
		}

		// An empty server-name override is passed to the TLS credentials.
		transport := grpc.WithTransportCredentials(credentials.NewClientTLSFromCert(systemPool, ""))
		opts.gopts = append(opts.gopts, transport)
		return nil
	}
}
// WithNamespace selects the namespace to log into by recording nsID in the
// client options.
func WithNamespace(nsID uint64) ClientOption {
	setNamespace := func(opts *clientOptions) error {
		opts.namespace = nsID
		return nil
	}
	return setNamespace
}
// WithACLCreds records the username and password to use for ACL
// authentication.
func WithACLCreds(username, password string) ClientOption {
	setCreds := func(opts *clientOptions) error {
		opts.username, opts.password = username, password
		return nil
	}
	return setCreds
}
// WithResponseFormat chooses the response format for queries. JSON is the
// default format; RDF can be requested instead.
func WithResponseFormat(respFormat api.Request_RespFormat) TxnOption {
	setFormat := func(opts *txnOptions) error {
		opts.respFormat = respFormat
		return nil
	}
	return setFormat
}
// WithGrpcOption appends an arbitrary grpc.DialOption to the client's gRPC
// options, for callers that need custom gRPC settings.
func WithGrpcOption(opt grpc.DialOption) ClientOption {
	addOption := func(opts *clientOptions) error {
		opts.gopts = append(opts.gopts, opt)
		return nil
	}
	return addOption
}
// Open creates a new Dgraph client by parsing a connection string of the form:
// dgraph://<optional-login>:<optional-password>@<host>:<port>?<optional-params>
// For example `dgraph://localhost:9080?sslmode=require`. An IPv6 host must be
// enclosed in square brackets, for example `dgraph://[::1]:9080`.
//
// Parameters:
//   - apikey: a Dgraph Cloud API key for authentication
//   - bearertoken: a token for bearer authentication
//   - namespace: the namespace ID to log into (unsigned base-10 integer)
//   - sslmode: SSL connection mode (options: disable, require, verify-ca)
//   - disable: No TLS (default)
//   - require: Use TLS but skip certificate verification
//   - verify-ca: Use TLS and verify the certificate against system CA
//
// apikey and bearertoken are mutually exclusive. When a login is present in
// the connection string, both the username and the password must be non-empty.
//
// If credentials are provided, Open connects to the gRPC endpoint and authenticates the user.
// An error can be returned if the Dgraph cluster is not yet ready to accept requests--the text
// of the error in this case will contain the string "Please retry".
func Open(connStr string) (*Dgraph, error) {
	u, err := url.Parse(connStr)
	if err != nil {
		return nil, fmt.Errorf("invalid connection string: %w", err)
	}

	params, err := url.ParseQuery(u.RawQuery)
	if err != nil {
		return nil, fmt.Errorf("malformed connection string: %w", err)
	}

	apiKey := params.Get(cloudAPIKeyParam)
	bearerToken := params.Get(bearerTokenParam)
	sslMode := params.Get(sslModeParam)
	nsParam := params.Get(namespaceParam)

	if u.Scheme != dgraphScheme {
		return nil, fmt.Errorf("invalid scheme: must start with %s://", dgraphScheme)
	}
	if apiKey != "" && bearerToken != "" {
		return nil, errors.New("invalid connection string: both apikey and bearertoken cannot be provided")
	}
	if err := checkConnHostPort(u); err != nil {
		return nil, err
	}

	var opts []ClientOption
	if apiKey != "" {
		opts = append(opts, WithDgraphAPIKey(apiKey))
	}
	if bearerToken != "" {
		opts = append(opts, WithBearerToken(bearerToken))
	}

	// A missing sslmode parameter means "disable".
	if sslMode == "" {
		sslMode = sslModeDisable
	}
	switch sslMode {
	case sslModeDisable:
		opts = append(opts, WithGrpcOption(grpc.WithTransportCredentials(insecure.NewCredentials())))
	case sslModeRequire:
		opts = append(opts, WithSkipTLSVerify())
	case sslModeVerifyCA:
		opts = append(opts, WithSystemCertPool())
	default:
		return nil, fmt.Errorf("invalid SSL mode: %s (must be one of %s, %s, %s)",
			sslMode, sslModeDisable, sslModeRequire, sslModeVerifyCA)
	}

	if nsParam != "" {
		ns, parseErr := strconv.ParseUint(nsParam, 10, 64)
		if parseErr != nil {
			return nil, fmt.Errorf("invalid namespace ID: %w", parseErr)
		}
		opts = append(opts, WithNamespace(ns))
	}

	if u.User != nil {
		username := u.User.Username()
		// The "password set" flag is not needed: an unset password is
		// returned as "" and is rejected by the emptiness check below.
		password, _ := u.User.Password()
		if username == "" || password == "" {
			return nil, errors.New("invalid connection string: both username and password must be provided")
		}
		opts = append(opts, WithACLCreds(username, password))
	}

	// u.Host keeps the brackets of an IPv6 literal, so it is passed on as-is.
	return NewClient(u.Host, opts...)
}

// checkConnHostPort verifies that the authority of u carries an explicit,
// non-empty port. It relies on net/url's host parsing instead of splitting on
// ":" so that bracketed IPv6 literals such as "[::1]:9080" are accepted.
func checkConnHostPort(u *url.URL) error {
	// Outside of a bracketed IPv6 literal, the only colon allowed is the one
	// that separates the host from the port.
	if !strings.HasPrefix(u.Host, "[") && strings.Count(u.Host, ":") != 1 {
		return errors.New("invalid connection string: host url must have both host and port")
	}
	if u.Port() != "" {
		return nil
	}
	if strings.HasSuffix(u.Host, ":") {
		return errors.New("invalid connection string: missing port after port-separator colon")
	}
	return errors.New("invalid connection string: host url must have both host and port")
}
// NewClient builds a Dgraph client that talks to one endpoint. It delegates
// to NewRoundRobinClient with a single-element endpoint list, so when ACL
// connection options are present a login is attempted with the supplied
// credentials.
func NewClient(endpoint string, opts ...ClientOption) (*Dgraph, error) {
	single := []string{endpoint}
	return NewRoundRobinClient(single, opts...)
}
// NewRoundRobinClient creates a new Dgraph client for a list
// of endpoints. It will round robin among the provided endpoints.
// If ACL connection options are present, a login attempt is made
// using the supplied credentials.
//
// The options are applied in order and the first option error is returned
// unchanged. At least one endpoint is required. After the optional login, the
// first endpoint is pinged with a CheckVersion call. Whenever an error is
// returned, every connection created so far has been closed.
func NewRoundRobinClient(endpoints []string, opts ...ClientOption) (*Dgraph, error) {
	// The ping below targets the first endpoint, so there must be one.
	if len(endpoints) == 0 {
		return nil, errors.New("failed to connect: no endpoints provided")
	}

	co := &clientOptions{}
	for _, opt := range opts {
		if err := opt(co); err != nil {
			return nil, err
		}
	}

	conns := make([]*grpc.ClientConn, 0, len(endpoints))
	dc := make([]api.DgraphClient, 0, len(endpoints))
	for _, endpoint := range endpoints {
		conn, err := grpc.NewClient(endpoint, co.gopts...)
		if err != nil {
			// Release the connections created for earlier endpoints; the
			// close errors are secondary to the one being reported.
			for _, opened := range conns {
				_ = opened.Close()
			}
			return nil, fmt.Errorf("failed to connect to endpoint [%s]: %w", endpoint, err)
		}
		conns = append(conns, conn)
		dc = append(dc, api.NewDgraphClient(conn))
	}

	// The connections must be stored on the client: Close iterates over
	// d.conns, so without them nothing would ever be closed.
	d := &Dgraph{dc: dc, conns: conns}
	if co.username != "" && co.password != "" {
		ctx, cancel := context.WithTimeout(context.Background(), requestTimeout)
		defer cancel()
		if err := d.login(ctx, co.username, co.password, co.namespace); err != nil {
			d.Close()
			return nil, fmt.Errorf("failed to sign in user: %w", err)
		}
	}

	if _, err := dc[0].CheckVersion(context.Background(), &api.Check{}); err != nil {
		d.Close()
		return nil, fmt.Errorf("failed to ping: %w", err)
	}
	return d, nil
}
// GetAPIClients returns the underlying api.DgraphClient values, for advanced
// cases that need gRPC APIs which dgo does not expose. The client's own slice
// is returned, not a copy.
func (d *Dgraph) GetAPIClients() []api.DgraphClient {
	clients := d.dc
	return clients
}
// Close shuts down all the connections to the Dgraph cluster.
func (d *Dgraph) Close() {
	for i := range d.conns {
		// Best effort: a failure to close one connection is ignored so the
		// remaining connections are still closed.
		_ = d.conns[i].Close()
	}
}