From 4ae0eee67678fd894e2ae8af0253741d1a465b1c Mon Sep 17 00:00:00 2001 From: Harish Chandramowli Date: Tue, 2 Mar 2021 09:38:50 -0500 Subject: [PATCH] Add code and info to mongoerror (#70) --- bsonutil.go | 19 +++++++++++++++ bsonutil_test.go | 15 ++++++++++++ cmd/sni_tester/sni_tester.go | 8 ------- inttests/int_test_utils.go | 8 ------- mongo_error.go | 45 +++++++++++++++++++++++++++++------- proxy_session.go | 2 -- 6 files changed, 71 insertions(+), 26 deletions(-) diff --git a/bsonutil.go b/bsonutil.go index d800067..aba1034 100644 --- a/bsonutil.go +++ b/bsonutil.go @@ -138,6 +138,25 @@ func GetAsBSON(elem bson.E) (bson.D, string, error) { } } +func GetAsStringArray(elem bson.E) ([]string, string, error) { + tipe := fmt.Sprintf("%T", elem.Value) + switch val := elem.Value.(type) { + case primitive.A: + res := make([]string, len(val)) + for num, raw := range []interface{}(val) { + switch fixed := raw.(type) { + case string: + res[num] = fixed + default: + return nil, tipe, NewStackErrorf("not string %T %s", raw, raw) + } + } + return res, tipe, nil + default: + return nil, tipe, NewStackErrorf("not an array %T", elem.Value) + } +} + func getAsBsonDocsArray(val []interface{}, tipe string) ([]bson.D, string, error) { a := make([]bson.D, len(val)) for num, raw := range val { diff --git a/bsonutil_test.go b/bsonutil_test.go index 812d40c..0cb950e 100644 --- a/bsonutil_test.go +++ b/bsonutil_test.go @@ -27,6 +27,21 @@ func TestBSONIndexOf(test *testing.T) { } } +func TestGetAsStringArray(test *testing.T) { + val := bson.A{"test1", "test2"} + doc := bson.E{"a", val} + res, _, _ := GetAsStringArray(doc) + if len(res) != 2 { + test.Errorf("result should of length 2, but got %v", len(res)) + } + if res[0] != "test1" { + test.Errorf("expected test1, but got %v", res[0]) + } + if res[1] != "test2" { + test.Errorf("expected test2, but got %v", res[0]) + } +} + type testWalker struct { seen []bson.E } diff --git a/cmd/sni_tester/sni_tester.go b/cmd/sni_tester/sni_tester.go index 417861d..84942e6 100644 --- a/cmd/sni_tester/sni_tester.go +++ b/cmd/sni_tester/sni_tester.go @@ -28,10 +28,6 @@ type MyInterceptor struct { ps *mongonet.ProxySession } -func (myi *MyInterceptor) GetClientMessage() mongonet.Message { - return nil -} - func (myi *MyInterceptor) sniResponse() mongonet.SimpleBSON { doc := bson.D{{"sniName", myi.ps.SSLServerName}, {"ok", 1}} raw, err := mongonet.SimpleBSONConvert(doc) @@ -41,10 +37,6 @@ func (myi *MyInterceptor) sniResponse() mongonet.SimpleBSON { return raw } -func (myi *MyInterceptor) SetClientMessage(message mongonet.Message) { - return -} - func (myi *MyInterceptor) InterceptClientToMongo(m mongonet.Message) ( mongonet.Message, mongonet.ResponseInterceptor, diff --git a/inttests/int_test_utils.go b/inttests/int_test_utils.go index c4f5daf..d478b92 100644 --- a/inttests/int_test_utils.go +++ b/inttests/int_test_utils.go @@ -216,10 +216,6 @@ type MyInterceptor struct { cursorManager *LightCursorManager } -func (myi *MyInterceptor) GetClientMessage() Message { - return nil -} - func (myi *MyInterceptor) Close() { } func (myi *MyInterceptor) TrackRequest(MessageHeader) { @@ -227,10 +223,6 @@ func (myi *MyInterceptor) TrackRequest(MessageHeader) { func (myi *MyInterceptor) TrackResponse(MessageHeader) { } -func (myi *MyInterceptor) SetClientMessage(message Message) { - return -} - func (myi *MyInterceptor) CheckConnection() error { return nil } diff --git a/mongo_error.go b/mongo_error.go index fedc398..2a4dce0 100644 --- a/mongo_error.go +++ b/mongo_error.go @@ -2,7 +2,6 @@ package mongonet import ( "fmt" - "go.mongodb.org/mongo-driver/bson" ) @@ -10,10 +9,24 @@ type MongoError struct { err error code int codeName string + labels []string } func NewMongoError(err error, code int, codeName string) MongoError { - return MongoError{err, code, codeName} + return MongoError{err, code, codeName, nil} +} + +func NewMongoErrorWithLabels(err error, code int, codeName string, labels []string) MongoError { + return MongoError{err, code, codeName, labels} +} + +func (me MongoError) HasLabel(label string) bool { + for _, val := range me.labels { + if val == label { + return true + } + } + return false } func (me MongoError) ToBSON() bson.D { @@ -30,11 +43,27 @@ func (me MongoError) ToBSON() bson.D { return doc } +func (me MongoError) GetCode() int { + return me.code +} + +func (me MongoError) GetCodeName() string { + return me.codeName +} + func (me MongoError) Error() string { - return fmt.Sprintf( - "code=%v codeName=%v errmsg = %v", - me.code, - me.codeName, - me.err.Error(), - ) + if me.err != nil { + return fmt.Sprintf( + "code=%v codeName=%v errmsg = %v", + me.code, + me.codeName, + me.err.Error(), + ) + } else { + return fmt.Sprintf( + "code=%v codeName=%v", + me.code, + me.codeName, + ) + } } diff --git a/proxy_session.go b/proxy_session.go index 9c0f0d4..e1ffcc6 100644 --- a/proxy_session.go +++ b/proxy_session.go @@ -62,8 +62,6 @@ type ProxyInterceptor interface { TrackResponse(MessageHeader) CheckConnection() error CheckConnectionInterval() time.Duration - SetClientMessage(message Message) - GetClientMessage() Message } type ProxyInterceptorFactory interface {