diff --git a/.github/workflows/golang-test-linux.yml b/.github/workflows/golang-test-linux.yml index ba5f66746e7..e4ab80b766f 100644 --- a/.github/workflows/golang-test-linux.yml +++ b/.github/workflows/golang-test-linux.yml @@ -194,7 +194,12 @@ jobs: run: docker pull mlsmaycon/warmed-mysql:8 - name: Test - run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -p 1 -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m $(go list ./... | grep /management) + run: | + CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \ + NETBIRD_STORE_ENGINE=${{ matrix.store }} \ + go test -tags=devcert -p 1 \ + -exec "sudo --preserve-env=CI,NETBIRD_STORE_ENGINE" \ + -timeout 10m $(go list ./... | grep /management) benchmark: needs: [ build-cache ] @@ -254,7 +259,12 @@ jobs: run: docker pull mlsmaycon/warmed-mysql:8 - name: Test - run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags devcert -run=^$ -bench=. -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 20m ./... + run: | + CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \ + NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true \ + go test -tags devcert -run=^$ -bench=. \ + -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \ + -timeout 10m ./... api_benchmark: needs: [ build-cache ] @@ -312,9 +322,14 @@ jobs: - name: download mysql image if: matrix.store == 'mysql' run: docker pull mlsmaycon/warmed-mysql:8 - + - name: Test - run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=benchmark -run=^$ -bench=. -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m $(go list ./... | grep /management) + run: | + CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \ + NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true \ + go test -tags=benchmark -run=^$ -bench=. \ + -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \ + -timeout 10m $(go list ./... | grep /management) api_integration_test: needs: [ build-cache ] @@ -363,7 +378,12 @@ jobs: run: git --no-pager diff --exit-code - name: Test - run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -p 1 -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m -tags=integration $(go list ./... | grep /management) + run: | + CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \ + NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true \ + go test -p 1 \ + -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \ + -timeout 10m -tags=integration $(go list ./... | grep /management) test_client_on_docker: needs: [ build-cache ] diff --git a/management/client/client_test.go b/management/client/client_test.go index 8bd8af8d2aa..d0fcacfad13 100644 --- a/management/client/client_test.go +++ b/management/client/client_test.go @@ -258,8 +258,11 @@ func TestClient_Sync(t *testing.T) { ch := make(chan *mgmtProto.SyncResponse, 1) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go func() { - err = client.Sync(context.Background(), info, func(msg *mgmtProto.SyncResponse) error { + err = client.Sync(ctx, info, func(msg *mgmtProto.SyncResponse) error { ch <- msg return nil }) diff --git a/management/server/dns_test.go b/management/server/dns_test.go index 6fb9f6a29a9..150953f51c7 100644 --- a/management/server/dns_test.go +++ b/management/server/dns_test.go @@ -129,7 +129,7 @@ func TestSaveDNSSettings(t *testing.T) { account, err := initTestDNSAccount(t, am) if err != nil { - t.Error("failed to init testing account") + t.Errorf("failed to init testing account: %v", err) } err = am.SaveDNSSettings(context.Background(), account.Id, testCase.userID, testCase.inputSettings) diff --git a/management/server/management_suite_test.go b/management/server/management_suite_test.go deleted file mode 100644 index cc99624a07e..00000000000 --- a/management/server/management_suite_test.go +++ /dev/null @@ -1,13 +0,0 @@ -package server_test - -import ( - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" - - "testing" -) - -func TestManagement(t *testing.T) { - RegisterFailHandler(Fail) - RunSpecs(t, "Management Service Suite") -} diff --git a/management/server/management_test.go b/management/server/management_test.go index cfa2c138f37..f791d813c6b 100644 --- a/management/server/management_test.go +++ b/management/server/management_test.go @@ -6,12 +6,11 @@ import ( "net" "os" "runtime" - sync2 "sync" + "sync" + "testing" "time" pb "github.com/golang/protobuf/proto" //nolint - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "google.golang.org/grpc" @@ -30,424 +29,76 @@ import ( const ( ValidSetupKey = "A2C8E62B-38F5-4553-B31E-DD66C696CEBB" - AccountKey = "bf1c8084-ba50-4ce7-9439-34653001fc3b" ) -var _ = Describe("Management service", func() { - var ( - addr string - s *grpc.Server - dataDir string - client mgmtProto.ManagementServiceClient - serverPubKey wgtypes.Key - conn *grpc.ClientConn - ) - - BeforeEach(func() { - level, _ := log.ParseLevel("Debug") - log.SetLevel(level) - var err error - dataDir, err = os.MkdirTemp("", "wiretrustee_mgmt_test_tmp_*") - Expect(err).NotTo(HaveOccurred()) - - var listener net.Listener - - config := &server.Config{} - _, err = util.ReadJson("testdata/management.json", config) - Expect(err).NotTo(HaveOccurred()) - config.Datadir = dataDir - - s, listener = startServer(config, dataDir, "testdata/store.sql") - addr = listener.Addr().String() - client, conn = createRawClient(addr) - - // s public key - resp, err := client.GetServerKey(context.TODO(), &mgmtProto.Empty{}) - Expect(err).NotTo(HaveOccurred()) - serverPubKey, err = wgtypes.ParseKey(resp.Key) - Expect(err).NotTo(HaveOccurred()) - }) - - AfterEach(func() { - s.Stop() - err := conn.Close() - Expect(err).NotTo(HaveOccurred()) - os.RemoveAll(dataDir) - }) - - Context("when calling IsHealthy endpoint", func() { - Specify("a non-error result is returned", func() { - healthy, err := client.IsHealthy(context.TODO(), &mgmtProto.Empty{}) - - Expect(err).NotTo(HaveOccurred()) - Expect(healthy).ToNot(BeNil()) - }) - }) - - Context("when calling Sync endpoint", func() { - Context("when there is a new peer registered", func() { - Specify("a proper configuration is returned", func() { - key, _ := wgtypes.GenerateKey() - loginPeerWithValidSetupKey(serverPubKey, key, client) - - syncReq := &mgmtProto.SyncRequest{Meta: &mgmtProto.PeerSystemMeta{}} - encryptedBytes, err := encryption.EncryptMessage(serverPubKey, key, syncReq) - Expect(err).NotTo(HaveOccurred()) - - sync, err := client.Sync(context.TODO(), &mgmtProto.EncryptedMessage{ - WgPubKey: key.PublicKey().String(), - Body: encryptedBytes, - }) - Expect(err).NotTo(HaveOccurred()) - - encryptedResponse := &mgmtProto.EncryptedMessage{} - err = sync.RecvMsg(encryptedResponse) - Expect(err).NotTo(HaveOccurred()) - - resp := &mgmtProto.SyncResponse{} - err = encryption.DecryptMessage(serverPubKey, key, encryptedResponse.Body, resp) - Expect(err).NotTo(HaveOccurred()) - - expectedSignalConfig := &mgmtProto.HostConfig{ - Uri: "signal.wiretrustee.com:10000", - Protocol: mgmtProto.HostConfig_HTTP, - } - expectedStunsConfig := &mgmtProto.HostConfig{ - Uri: "stun:stun.wiretrustee.com:3468", - Protocol: mgmtProto.HostConfig_UDP, - } - expectedTRUNHost := &mgmtProto.HostConfig{ - Uri: "turn:stun.wiretrustee.com:3468", - Protocol: mgmtProto.HostConfig_UDP, - } - - Expect(resp.WiretrusteeConfig.Signal).To(BeEquivalentTo(expectedSignalConfig)) - Expect(resp.WiretrusteeConfig.Stuns).To(ConsistOf(expectedStunsConfig)) - // TURN validation is special because credentials are dynamically generated - Expect(resp.WiretrusteeConfig.Turns).To(HaveLen(1)) - actualTURN := resp.WiretrusteeConfig.Turns[0] - Expect(len(actualTURN.User) > 0).To(BeTrue()) - Expect(actualTURN.HostConfig).To(BeEquivalentTo(expectedTRUNHost)) - Expect(len(resp.NetworkMap.OfflinePeers) == 0).To(BeTrue()) - }) - }) - - Context("when there are 3 peers registered under one account", func() { - Specify("a list containing other 2 peers is returned", func() { - key, _ := wgtypes.GenerateKey() - key1, _ := wgtypes.GenerateKey() - key2, _ := wgtypes.GenerateKey() - loginPeerWithValidSetupKey(serverPubKey, key, client) - loginPeerWithValidSetupKey(serverPubKey, key1, client) - loginPeerWithValidSetupKey(serverPubKey, key2, client) - - messageBytes, err := pb.Marshal(&mgmtProto.SyncRequest{Meta: &mgmtProto.PeerSystemMeta{}}) - Expect(err).NotTo(HaveOccurred()) - encryptedBytes, err := encryption.Encrypt(messageBytes, serverPubKey, key) - Expect(err).NotTo(HaveOccurred()) - - sync, err := client.Sync(context.TODO(), &mgmtProto.EncryptedMessage{ - WgPubKey: key.PublicKey().String(), - Body: encryptedBytes, - }) - Expect(err).NotTo(HaveOccurred()) - - encryptedResponse := &mgmtProto.EncryptedMessage{} - err = sync.RecvMsg(encryptedResponse) - Expect(err).NotTo(HaveOccurred()) - decryptedBytes, err := encryption.Decrypt(encryptedResponse.Body, serverPubKey, key) - Expect(err).NotTo(HaveOccurred()) - - resp := &mgmtProto.SyncResponse{} - err = pb.Unmarshal(decryptedBytes, resp) - Expect(err).NotTo(HaveOccurred()) - - Expect(resp.GetRemotePeers()).To(HaveLen(2)) - peers := []string{resp.GetRemotePeers()[0].WgPubKey, resp.GetRemotePeers()[1].WgPubKey} - Expect(peers).To(ContainElements(key1.PublicKey().String(), key2.PublicKey().String())) - }) - }) - - Context("when there is a new peer registered", func() { - Specify("an update is returned", func() { - // register only a single peer - key, _ := wgtypes.GenerateKey() - loginPeerWithValidSetupKey(serverPubKey, key, client) - - messageBytes, err := pb.Marshal(&mgmtProto.SyncRequest{Meta: &mgmtProto.PeerSystemMeta{}}) - Expect(err).NotTo(HaveOccurred()) - encryptedBytes, err := encryption.Encrypt(messageBytes, serverPubKey, key) - Expect(err).NotTo(HaveOccurred()) - - sync, err := client.Sync(context.TODO(), &mgmtProto.EncryptedMessage{ - WgPubKey: key.PublicKey().String(), - Body: encryptedBytes, - }) - Expect(err).NotTo(HaveOccurred()) - - // after the initial sync call we have 0 peer updates - encryptedResponse := &mgmtProto.EncryptedMessage{} - err = sync.RecvMsg(encryptedResponse) - Expect(err).NotTo(HaveOccurred()) - decryptedBytes, err := encryption.Decrypt(encryptedResponse.Body, serverPubKey, key) - Expect(err).NotTo(HaveOccurred()) - resp := &mgmtProto.SyncResponse{} - err = pb.Unmarshal(decryptedBytes, resp) - Expect(resp.GetRemotePeers()).To(HaveLen(0)) - - wg := sync2.WaitGroup{} - wg.Add(1) - - // continue listening on updates for a peer - go func() { - err = sync.RecvMsg(encryptedResponse) - - decryptedBytes, err = encryption.Decrypt(encryptedResponse.Body, serverPubKey, key) - Expect(err).NotTo(HaveOccurred()) - resp = &mgmtProto.SyncResponse{} - err = pb.Unmarshal(decryptedBytes, resp) - wg.Done() - }() - - // register a new peer - key1, _ := wgtypes.GenerateKey() - loginPeerWithValidSetupKey(serverPubKey, key1, client) - - wg.Wait() - - Expect(err).NotTo(HaveOccurred()) - Expect(resp.GetRemotePeers()).To(HaveLen(1)) - Expect(resp.GetRemotePeers()[0].WgPubKey).To(BeEquivalentTo(key1.PublicKey().String())) - }) - }) - }) - - Context("when calling GetServerKey endpoint", func() { - Specify("a public Wireguard key of the service is returned", func() { - resp, err := client.GetServerKey(context.TODO(), &mgmtProto.Empty{}) - - Expect(err).NotTo(HaveOccurred()) - Expect(resp).ToNot(BeNil()) - Expect(resp.Key).ToNot(BeNil()) - Expect(resp.ExpiresAt).ToNot(BeNil()) - - // check if the key is a valid Wireguard key - key, err := wgtypes.ParseKey(resp.Key) - Expect(err).NotTo(HaveOccurred()) - Expect(key).ToNot(BeNil()) - }) - }) - - Context("when calling Login endpoint", func() { - Context("with an invalid setup key", func() { - Specify("an error is returned", func() { - key, _ := wgtypes.GenerateKey() - message, err := encryption.EncryptMessage(serverPubKey, key, &mgmtProto.LoginRequest{SetupKey: "invalid setup key", - Meta: &mgmtProto.PeerSystemMeta{}}) - Expect(err).NotTo(HaveOccurred()) - - resp, err := client.Login(context.TODO(), &mgmtProto.EncryptedMessage{ - WgPubKey: key.PublicKey().String(), - Body: message, - }) - - Expect(err).To(HaveOccurred()) - Expect(resp).To(BeNil()) - }) - }) - - Context("with a valid setup key", func() { - It("a non error result is returned", func() { - key, _ := wgtypes.GenerateKey() - resp := loginPeerWithValidSetupKey(serverPubKey, key, client) - - Expect(resp).ToNot(BeNil()) - }) - }) - - Context("with a registered peer", func() { - It("a non error result is returned", func() { - key, _ := wgtypes.GenerateKey() - regResp := loginPeerWithValidSetupKey(serverPubKey, key, client) - Expect(regResp).NotTo(BeNil()) - - // just login without registration - message, err := encryption.EncryptMessage(serverPubKey, key, &mgmtProto.LoginRequest{Meta: &mgmtProto.PeerSystemMeta{}}) - Expect(err).NotTo(HaveOccurred()) - loginResp, err := client.Login(context.TODO(), &mgmtProto.EncryptedMessage{ - WgPubKey: key.PublicKey().String(), - Body: message, - }) - - Expect(err).NotTo(HaveOccurred()) - - decryptedResp := &mgmtProto.LoginResponse{} - err = encryption.DecryptMessage(serverPubKey, key, loginResp.Body, decryptedResp) - Expect(err).NotTo(HaveOccurred()) - - expectedSignalConfig := &mgmtProto.HostConfig{ - Uri: "signal.wiretrustee.com:10000", - Protocol: mgmtProto.HostConfig_HTTP, - } - expectedStunsConfig := &mgmtProto.HostConfig{ - Uri: "stun:stun.wiretrustee.com:3468", - Protocol: mgmtProto.HostConfig_UDP, - } - expectedTurnsConfig := &mgmtProto.ProtectedHostConfig{ - HostConfig: &mgmtProto.HostConfig{ - Uri: "turn:stun.wiretrustee.com:3468", - Protocol: mgmtProto.HostConfig_UDP, - }, - User: "some_user", - Password: "some_password", - } - - Expect(decryptedResp.GetWiretrusteeConfig().Signal).To(BeEquivalentTo(expectedSignalConfig)) - Expect(decryptedResp.GetWiretrusteeConfig().Stuns).To(ConsistOf(expectedStunsConfig)) - Expect(decryptedResp.GetWiretrusteeConfig().Turns).To(ConsistOf(expectedTurnsConfig)) - }) - }) - }) +type testSuite struct { + t *testing.T + addr string + grpcServer *grpc.Server + dataDir string + client mgmtProto.ManagementServiceClient + serverPubKey wgtypes.Key + conn *grpc.ClientConn +} - Context("when there are 10 peers registered under one account", func() { - Context("when there are 10 more peers registered under the same account", func() { - Specify("all of the 10 peers will get updates of 10 newly registered peers", func() { - initialPeers := 10 - additionalPeers := 10 - - var peers []wgtypes.Key - for i := 0; i < initialPeers; i++ { - key, _ := wgtypes.GenerateKey() - loginPeerWithValidSetupKey(serverPubKey, key, client) - peers = append(peers, key) - } +func setupTest(t *testing.T) *testSuite { + t.Helper() + level, _ := log.ParseLevel("Debug") + log.SetLevel(level) - wg := sync2.WaitGroup{} - wg.Add(initialPeers + initialPeers*additionalPeers) - - var clients []mgmtProto.ManagementService_SyncClient - for _, peer := range peers { - messageBytes, err := pb.Marshal(&mgmtProto.SyncRequest{Meta: &mgmtProto.PeerSystemMeta{}}) - Expect(err).NotTo(HaveOccurred()) - encryptedBytes, err := encryption.Encrypt(messageBytes, serverPubKey, peer) - Expect(err).NotTo(HaveOccurred()) - - // open stream - sync, err := client.Sync(context.TODO(), &mgmtProto.EncryptedMessage{ - WgPubKey: peer.PublicKey().String(), - Body: encryptedBytes, - }) - Expect(err).NotTo(HaveOccurred()) - clients = append(clients, sync) - - // receive stream - peer := peer - go func() { - for { - encryptedResponse := &mgmtProto.EncryptedMessage{} - err = sync.RecvMsg(encryptedResponse) - if err != nil { - break - } - decryptedBytes, err := encryption.Decrypt(encryptedResponse.Body, serverPubKey, peer) - Expect(err).NotTo(HaveOccurred()) - - resp := &mgmtProto.SyncResponse{} - err = pb.Unmarshal(decryptedBytes, resp) - Expect(err).NotTo(HaveOccurred()) - if len(resp.GetRemotePeers()) > 0 { - // only consider peer updates - wg.Done() - } - } - }() - } + ts := &testSuite{t: t} - time.Sleep(1 * time.Second) - for i := 0; i < additionalPeers; i++ { - key, _ := wgtypes.GenerateKey() - loginPeerWithValidSetupKey(serverPubKey, key, client) - r := rand.New(rand.NewSource(time.Now().UnixNano())) - n := r.Intn(200) - time.Sleep(time.Duration(n) * time.Millisecond) - } + var err error + ts.dataDir, err = os.MkdirTemp("", "wiretrustee_mgmt_test_tmp_*") + if err != nil { + t.Fatalf("failed to create temp directory: %v", err) + } - wg.Wait() + config := &server.Config{} + _, err = util.ReadJson("testdata/management.json", config) + if err != nil { + t.Fatalf("failed to read management.json: %v", err) + } + config.Datadir = ts.dataDir - for _, syncClient := range clients { - err := syncClient.CloseSend() - Expect(err).NotTo(HaveOccurred()) - } - }) - }) - }) + var listener net.Listener + ts.grpcServer, listener = startServer(t, config, ts.dataDir, "testdata/store.sql") + ts.addr = listener.Addr().String() - Context("when there are peers registered under one account concurrently", func() { - Specify("then there are no duplicate IPs", func() { - initialPeers := 30 - - ipChannel := make(chan string, 20) - for i := 0; i < initialPeers; i++ { - go func() { - defer GinkgoRecover() - key, _ := wgtypes.GenerateKey() - loginPeerWithValidSetupKey(serverPubKey, key, client) - syncReq := &mgmtProto.SyncRequest{Meta: &mgmtProto.PeerSystemMeta{}} - encryptedBytes, err := encryption.EncryptMessage(serverPubKey, key, syncReq) - Expect(err).NotTo(HaveOccurred()) - - // open stream - sync, err := client.Sync(context.TODO(), &mgmtProto.EncryptedMessage{ - WgPubKey: key.PublicKey().String(), - Body: encryptedBytes, - }) - Expect(err).NotTo(HaveOccurred()) - encryptedResponse := &mgmtProto.EncryptedMessage{} - err = sync.RecvMsg(encryptedResponse) - Expect(err).NotTo(HaveOccurred()) - - resp := &mgmtProto.SyncResponse{} - err = encryption.DecryptMessage(serverPubKey, key, encryptedResponse.Body, resp) - Expect(err).NotTo(HaveOccurred()) - - ipChannel <- resp.GetPeerConfig().Address - }() - } + ts.client, ts.conn = createRawClient(t, ts.addr) - ips := make(map[string]struct{}) - for ip := range ipChannel { - if _, ok := ips[ip]; ok { - Fail("found duplicate IP: " + ip) - } - ips[ip] = struct{}{} - if len(ips) == initialPeers { - break - } - } - close(ipChannel) - }) - }) - - Context("after login two peers", func() { - Specify("then they receive the same network", func() { - key, _ := wgtypes.GenerateKey() - firstLogin := loginPeerWithValidSetupKey(serverPubKey, key, client) - key, _ = wgtypes.GenerateKey() - secondLogin := loginPeerWithValidSetupKey(serverPubKey, key, client) + resp, err := ts.client.GetServerKey(context.TODO(), &mgmtProto.Empty{}) + if err != nil { + t.Fatalf("failed to get server key: %v", err) + } - _, firstLoginNetwork, err := net.ParseCIDR(firstLogin.GetPeerConfig().GetAddress()) - Expect(err).NotTo(HaveOccurred()) - _, secondLoginNetwork, err := net.ParseCIDR(secondLogin.GetPeerConfig().GetAddress()) - Expect(err).NotTo(HaveOccurred()) + serverKey, err := wgtypes.ParseKey(resp.Key) + if err != nil { + t.Fatalf("failed to parse server key: %v", err) + } + ts.serverPubKey = serverKey - Expect(secondLoginNetwork.String()).To(BeEquivalentTo(firstLoginNetwork.String())) - }) - }) -}) + return ts +} -func loginPeerWithValidSetupKey(serverPubKey wgtypes.Key, key wgtypes.Key, client mgmtProto.ManagementServiceClient) *mgmtProto.LoginResponse { - defer GinkgoRecover() +func tearDownTest(t *testing.T, ts *testSuite) { + t.Helper() + ts.grpcServer.Stop() + if err := ts.conn.Close(); err != nil { + t.Fatalf("failed to close client connection: %v", err) + } + if err := os.RemoveAll(ts.dataDir); err != nil { + t.Fatalf("failed to remove data directory %s: %v", ts.dataDir, err) + } +} +func loginPeerWithValidSetupKey( + t *testing.T, + serverPubKey wgtypes.Key, + key wgtypes.Key, + client mgmtProto.ManagementServiceClient, +) *mgmtProto.LoginResponse { + t.Helper() meta := &mgmtProto.PeerSystemMeta{ Hostname: key.PublicKey().String(), GoOS: runtime.GOOS, @@ -457,23 +108,30 @@ func loginPeerWithValidSetupKey(serverPubKey wgtypes.Key, key wgtypes.Key, clien Kernel: "kernel", WiretrusteeVersion: "", } - message, err := encryption.EncryptMessage(serverPubKey, key, &mgmtProto.LoginRequest{SetupKey: ValidSetupKey, Meta: meta}) - Expect(err).NotTo(HaveOccurred()) + msgToEncrypt := &mgmtProto.LoginRequest{SetupKey: ValidSetupKey, Meta: meta} + message, err := encryption.EncryptMessage(serverPubKey, key, msgToEncrypt) + if err != nil { + t.Fatalf("failed to encrypt login request: %v", err) + } resp, err := client.Login(context.TODO(), &mgmtProto.EncryptedMessage{ WgPubKey: key.PublicKey().String(), Body: message, }) - - Expect(err).NotTo(HaveOccurred()) + if err != nil { + t.Fatalf("login request failed: %v", err) + } loginResp := &mgmtProto.LoginResponse{} err = encryption.DecryptMessage(serverPubKey, key, resp.Body, loginResp) - Expect(err).NotTo(HaveOccurred()) + if err != nil { + t.Fatalf("failed to decrypt login response: %v", err) + } return loginResp } -func createRawClient(addr string) (mgmtProto.ManagementServiceClient, *grpc.ClientConn) { +func createRawClient(t *testing.T, addr string) (mgmtProto.ManagementServiceClient, *grpc.ClientConn) { + t.Helper() ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() @@ -484,17 +142,27 @@ func createRawClient(addr string) (mgmtProto.ManagementServiceClient, *grpc.Clie Time: 10 * time.Second, Timeout: 2 * time.Second, })) - Expect(err).NotTo(HaveOccurred()) + if err != nil { + t.Fatalf("failed to dial gRPC server: %v", err) + } return mgmtProto.NewManagementServiceClient(conn), conn } -func startServer(config *server.Config, dataDir string, testFile string) (*grpc.Server, net.Listener) { +func startServer( + t *testing.T, + config *server.Config, + dataDir string, + testFile string, +) (*grpc.Server, net.Listener) { + t.Helper() lis, err := net.Listen("tcp", ":0") - Expect(err).NotTo(HaveOccurred()) + if err != nil { + t.Fatalf("failed to listen on a random port: %v", err) + } s := grpc.NewServer() - store, _, err := store.NewTestStoreFromSQL(context.Background(), testFile, dataDir) + str, _, err := store.NewTestStoreFromSQL(context.Background(), testFile, dataDir) if err != nil { log.Fatalf("failed creating a store: %s: %v", config.Datadir, err) } @@ -504,23 +172,511 @@ func startServer(config *server.Config, dataDir string, testFile string) (*grpc. metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) if err != nil { - log.Fatalf("failed creating metrics: %v", err) + t.Fatalf("failed creating metrics: %v", err) } - accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, server.MocIntegratedValidator{}, metrics) + accountManager, err := server.BuildManager( + context.Background(), + str, + peersUpdateManager, + nil, + "", + "netbird.selfhosted", + eventStore, + nil, + false, + server.MocIntegratedValidator{}, + metrics, + ) if err != nil { - log.Fatalf("failed creating a manager: %v", err) + t.Fatalf("failed creating an account manager: %v", err) } secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay) - mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, nil) - Expect(err).NotTo(HaveOccurred()) + mgmtServer, err := server.NewServer( + context.Background(), + config, + accountManager, + settings.NewManager(str), + peersUpdateManager, + secretsManager, + nil, + nil, + ) + if err != nil { + t.Fatalf("failed creating management server: %v", err) + } + mgmtProto.RegisterManagementServiceServer(s, mgmtServer) + go func() { if err := s.Serve(lis); err != nil { - Expect(err).NotTo(HaveOccurred()) + t.Errorf("failed to serve gRPC: %v", err) + return } }() return s, lis } + +func TestIsHealthy(t *testing.T) { + ts := setupTest(t) + defer tearDownTest(t, ts) + + healthy, err := ts.client.IsHealthy(context.TODO(), &mgmtProto.Empty{}) + if err != nil { + t.Fatalf("IsHealthy call returned an error: %v", err) + } + if healthy == nil { + t.Fatal("IsHealthy returned a nil response") + } +} + +func TestSyncNewPeerConfiguration(t *testing.T) { + ts := setupTest(t) + defer tearDownTest(t, ts) + + peerKey, _ := wgtypes.GenerateKey() + loginPeerWithValidSetupKey(t, ts.serverPubKey, peerKey, ts.client) + + syncReq := &mgmtProto.SyncRequest{Meta: &mgmtProto.PeerSystemMeta{}} + encryptedBytes, err := encryption.EncryptMessage(ts.serverPubKey, peerKey, syncReq) + if err != nil { + t.Fatalf("failed to encrypt sync request: %v", err) + } + + syncStream, err := ts.client.Sync(context.TODO(), &mgmtProto.EncryptedMessage{ + WgPubKey: peerKey.PublicKey().String(), + Body: encryptedBytes, + }) + if err != nil { + t.Fatalf("failed to call Sync: %v", err) + } + + encryptedResponse := &mgmtProto.EncryptedMessage{} + err = syncStream.RecvMsg(encryptedResponse) + if err != nil { + t.Fatalf("failed to receive sync response message: %v", err) + } + + resp := &mgmtProto.SyncResponse{} + err = encryption.DecryptMessage(ts.serverPubKey, peerKey, encryptedResponse.Body, resp) + if err != nil { + t.Fatalf("failed to decrypt sync response: %v", err) + } + + if resp.WiretrusteeConfig == nil { + t.Fatal("WiretrusteeConfig is nil in SyncResponse") + } + if resp.WiretrusteeConfig.Signal == nil { + t.Fatal("Expected Signal config in SyncResponse, got nil") + } + if len(resp.WiretrusteeConfig.Stuns) == 0 { + t.Fatal("Expected at least one STUN config in SyncResponse, got none") + } + if len(resp.WiretrusteeConfig.Turns) == 0 { + t.Fatal("Expected at least one TURN config in SyncResponse, got none") + } + if len(resp.NetworkMap.OfflinePeers) != 0 { + t.Fatalf("Expected 0 offline peers, got %d", len(resp.NetworkMap.OfflinePeers)) + } +} + +func TestSyncThreePeers(t *testing.T) { + ts := setupTest(t) + defer tearDownTest(t, ts) + + peerKey, _ := wgtypes.GenerateKey() + peerKey1, _ := wgtypes.GenerateKey() + peerKey2, _ := wgtypes.GenerateKey() + + loginPeerWithValidSetupKey(t, ts.serverPubKey, peerKey, ts.client) + loginPeerWithValidSetupKey(t, ts.serverPubKey, peerKey1, ts.client) + loginPeerWithValidSetupKey(t, ts.serverPubKey, peerKey2, ts.client) + + syncReq := &mgmtProto.SyncRequest{Meta: &mgmtProto.PeerSystemMeta{}} + syncBytes, err := pb.Marshal(syncReq) + if err != nil { + t.Fatalf("failed to marshal sync request: %v", err) + } + encryptedBytes, err := encryption.Encrypt(syncBytes, ts.serverPubKey, peerKey) + if err != nil { + t.Fatalf("failed to encrypt sync request: %v", err) + } + + syncStream, err := ts.client.Sync(context.TODO(), &mgmtProto.EncryptedMessage{ + WgPubKey: peerKey.PublicKey().String(), + Body: encryptedBytes, + }) + if err != nil { + t.Fatalf("failed to call Sync: %v", err) + } + + encryptedResponse := &mgmtProto.EncryptedMessage{} + err = syncStream.RecvMsg(encryptedResponse) + if err != nil { + t.Fatalf("failed to receive sync response: %v", err) + } + + decryptedBytes, err := encryption.Decrypt(encryptedResponse.Body, ts.serverPubKey, peerKey) + if err != nil { + t.Fatalf("failed to decrypt sync response: %v", err) + } + + resp := &mgmtProto.SyncResponse{} + err = pb.Unmarshal(decryptedBytes, resp) + if err != nil { + t.Fatalf("failed to unmarshal sync response: %v", err) + } + + if len(resp.GetRemotePeers()) != 2 { + t.Fatalf("expected 2 remote peers, got %d", len(resp.GetRemotePeers())) + } + + var found1, found2 bool + for _, rp := range resp.GetRemotePeers() { + if rp.WgPubKey == peerKey1.PublicKey().String() { + found1 = true + } else if rp.WgPubKey == peerKey2.PublicKey().String() { + found2 = true + } + } + if !found1 || !found2 { + t.Fatalf("did not find the expected peer keys %s, %s among %v", + peerKey1.PublicKey().String(), + peerKey2.PublicKey().String(), + resp.GetRemotePeers()) + } +} + +func TestSyncNewPeerUpdate(t *testing.T) { + ts := setupTest(t) + defer tearDownTest(t, ts) + + peerKey, _ := wgtypes.GenerateKey() + loginPeerWithValidSetupKey(t, ts.serverPubKey, peerKey, ts.client) + + syncReq := &mgmtProto.SyncRequest{Meta: &mgmtProto.PeerSystemMeta{}} + syncBytes, err := pb.Marshal(syncReq) + if err != nil { + t.Fatalf("failed to marshal sync request: %v", err) + } + + encryptedBytes, err := encryption.Encrypt(syncBytes, ts.serverPubKey, peerKey) + if err != nil { + t.Fatalf("failed to encrypt sync request: %v", err) + } + + syncStream, err := ts.client.Sync(context.TODO(), &mgmtProto.EncryptedMessage{ + WgPubKey: peerKey.PublicKey().String(), + Body: encryptedBytes, + }) + if err != nil { + t.Fatalf("failed to call Sync: %v", err) + } + + encryptedResponse := &mgmtProto.EncryptedMessage{} + err = syncStream.RecvMsg(encryptedResponse) + if err != nil { + t.Fatalf("failed to receive first sync response: %v", err) + } + + decryptedBytes, err := encryption.Decrypt(encryptedResponse.Body, ts.serverPubKey, peerKey) + if err != nil { + t.Fatalf("failed to decrypt first sync response: %v", err) + } + + resp := &mgmtProto.SyncResponse{} + if err := pb.Unmarshal(decryptedBytes, resp); err != nil { + t.Fatalf("failed to unmarshal first sync response: %v", err) + } + + if len(resp.GetRemotePeers()) != 0 { + t.Fatalf("expected 0 remote peers at first sync, got %d", len(resp.GetRemotePeers())) + } + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + encryptedResponse := &mgmtProto.EncryptedMessage{} + err = syncStream.RecvMsg(encryptedResponse) + if err != nil { + t.Errorf("failed to receive second sync response: %v", err) + return + } + + decryptedBytes, err := encryption.Decrypt(encryptedResponse.Body, ts.serverPubKey, peerKey) + if err != nil { + t.Errorf("failed to decrypt second sync response: %v", err) + return + } + err = pb.Unmarshal(decryptedBytes, resp) + if err != nil { + t.Errorf("failed to unmarshal second sync response: %v", err) + return + } + }() + + newPeerKey, _ := wgtypes.GenerateKey() + loginPeerWithValidSetupKey(t, ts.serverPubKey, newPeerKey, ts.client) + + wg.Wait() + + if len(resp.GetRemotePeers()) != 1 { + t.Fatalf("expected exactly 1 remote peer update, got %d", len(resp.GetRemotePeers())) + } + if resp.GetRemotePeers()[0].WgPubKey != newPeerKey.PublicKey().String() { + t.Fatalf("expected new peer key %s, got %s", + newPeerKey.PublicKey().String(), + resp.GetRemotePeers()[0].WgPubKey) + } +} + +func TestGetServerKey(t *testing.T) { + ts := setupTest(t) + defer tearDownTest(t, ts) + + resp, err := ts.client.GetServerKey(context.TODO(), &mgmtProto.Empty{}) + if err != nil { + t.Fatalf("GetServerKey returned error: %v", err) + } + if resp == nil { + t.Fatal("GetServerKey returned nil response") + } + if resp.Key == "" { + t.Fatal("GetServerKey returned empty key") + } + if resp.ExpiresAt.AsTime().IsZero() { + t.Fatal("GetServerKey returned 0 for ExpiresAt") + } + + _, err = wgtypes.ParseKey(resp.Key) + if err != nil { + t.Fatalf("GetServerKey returned an invalid WG key: %v", err) + } +} + +func TestLoginInvalidSetupKey(t *testing.T) { + ts := setupTest(t) + defer tearDownTest(t, ts) + + peerKey, _ := wgtypes.GenerateKey() + request := &mgmtProto.LoginRequest{ + SetupKey: "invalid setup key", + Meta: &mgmtProto.PeerSystemMeta{}, + } + encryptedMsg, err := encryption.EncryptMessage(ts.serverPubKey, peerKey, request) + if err != nil { + t.Fatalf("failed to encrypt login request: %v", err) + } + + resp, err := ts.client.Login(context.TODO(), &mgmtProto.EncryptedMessage{ + WgPubKey: peerKey.PublicKey().String(), + Body: encryptedMsg, + }) + if err == nil { + t.Fatal("expected error for invalid setup key but got nil") + } + if resp != nil { + t.Fatalf("expected nil response for invalid setup key but got: %+v", resp) + } +} + +func TestLoginValidSetupKey(t *testing.T) { + ts := setupTest(t) + defer tearDownTest(t, ts) + + peerKey, _ := wgtypes.GenerateKey() + resp := loginPeerWithValidSetupKey(t, ts.serverPubKey, peerKey, ts.client) + if resp == nil { + t.Fatal("loginPeerWithValidSetupKey returned nil, expected a valid response") + } +} + +func TestLoginRegisteredPeer(t *testing.T) { + ts := setupTest(t) + defer tearDownTest(t, ts) + + peerKey, _ := wgtypes.GenerateKey() + regResp := loginPeerWithValidSetupKey(t, ts.serverPubKey, peerKey, ts.client) + if regResp == nil { + t.Fatal("registration with valid setup key failed") + } + + loginReq := &mgmtProto.LoginRequest{Meta: &mgmtProto.PeerSystemMeta{}} + encryptedLogin, err := encryption.EncryptMessage(ts.serverPubKey, peerKey, loginReq) + if err != nil { + t.Fatalf("failed to encrypt login request: %v", err) + } + loginRespEnc, err := ts.client.Login(context.TODO(), &mgmtProto.EncryptedMessage{ + WgPubKey: peerKey.PublicKey().String(), + Body: encryptedLogin, + }) + if err != nil { + t.Fatalf("login call returned an error: %v", err) + } + + loginResp := &mgmtProto.LoginResponse{} + err = encryption.DecryptMessage(ts.serverPubKey, peerKey, loginRespEnc.Body, loginResp) + if err != nil { + t.Fatalf("failed to decrypt login response: %v", err) + } + + if loginResp.GetWiretrusteeConfig() == nil { + t.Fatal("WiretrusteeConfig is nil in login response") + } + if len(loginResp.GetWiretrusteeConfig().Stuns) == 0 { + t.Fatal("Expected STUN servers in login response, got none") + } + if len(loginResp.GetWiretrusteeConfig().Turns) == 0 { + t.Fatal("Expected TURN servers in login response, got none") + } +} + +func TestSync10PeersGetUpdates(t *testing.T) { + ts := setupTest(t) + defer tearDownTest(t, ts) + + initialPeers := 10 + additionalPeers := 10 + + var peers []wgtypes.Key + for i := 0; i < initialPeers; i++ { + key, _ := wgtypes.GenerateKey() + loginPeerWithValidSetupKey(t, ts.serverPubKey, key, ts.client) + peers = append(peers, key) + } + + var wg sync.WaitGroup + wg.Add(initialPeers + initialPeers*additionalPeers) + + var syncClients []mgmtProto.ManagementService_SyncClient + for _, pk := range peers { + syncReq := &mgmtProto.SyncRequest{Meta: &mgmtProto.PeerSystemMeta{}} + msgBytes, err := pb.Marshal(syncReq) + if err != nil { + t.Fatalf("failed to marshal SyncRequest: %v", err) + } + encBytes, err := encryption.Encrypt(msgBytes, ts.serverPubKey, pk) + if err != nil { + t.Fatalf("failed to encrypt SyncRequest: %v", err) + } + + s, err := ts.client.Sync(context.TODO(), &mgmtProto.EncryptedMessage{ + WgPubKey: pk.PublicKey().String(), + Body: encBytes, + }) + if err != nil { + t.Fatalf("failed to call Sync for peer: %v", err) + } + syncClients = append(syncClients, s) + + go func(pk wgtypes.Key, syncStream mgmtProto.ManagementService_SyncClient) { + for { + encMsg := &mgmtProto.EncryptedMessage{} + err := syncStream.RecvMsg(encMsg) + if err != nil { + return + } + decryptedBytes, decErr := encryption.Decrypt(encMsg.Body, ts.serverPubKey, pk) + if decErr != nil { + t.Errorf("failed to decrypt SyncResponse for peer %s: %v", pk.PublicKey().String(), decErr) + return + } + resp := &mgmtProto.SyncResponse{} + umErr := pb.Unmarshal(decryptedBytes, resp) + if umErr != nil { + t.Errorf("failed to unmarshal SyncResponse for peer %s: %v", pk.PublicKey().String(), umErr) + return + } + // We only count if there's a new peer update + if len(resp.GetRemotePeers()) > 0 { + wg.Done() + } + } + }(pk, s) + } + + time.Sleep(500 * time.Millisecond) + for i := 0; i < additionalPeers; i++ { + key, _ := wgtypes.GenerateKey() + loginPeerWithValidSetupKey(t, ts.serverPubKey, key, ts.client) + r := rand.New(rand.NewSource(time.Now().UnixNano())) + n := r.Intn(200) + time.Sleep(time.Duration(n) * time.Millisecond) + } + + wg.Wait() + + for _, sc := range syncClients { + err := sc.CloseSend() + if err != nil { + t.Fatalf("failed to close sync client: %v", err) + } + } +} + +func TestConcurrentPeersNoDuplicateIPs(t *testing.T) { + ts := setupTest(t) + defer tearDownTest(t, ts) + + initialPeers := 30 + ipChan := make(chan string, initialPeers) + + var wg sync.WaitGroup + wg.Add(initialPeers) + + for i := 0; i < initialPeers; i++ { + go func() { + defer wg.Done() + key, _ := wgtypes.GenerateKey() + loginPeerWithValidSetupKey(t, ts.serverPubKey, key, ts.client) + + syncReq := &mgmtProto.SyncRequest{Meta: &mgmtProto.PeerSystemMeta{}} + encryptedBytes, err := encryption.EncryptMessage(ts.serverPubKey, key, syncReq) + if err != nil { + t.Errorf("failed to encrypt sync request: %v", err) + return + } + + s, err := ts.client.Sync(context.TODO(), &mgmtProto.EncryptedMessage{ + WgPubKey: key.PublicKey().String(), + Body: encryptedBytes, + }) + if err != nil { + t.Errorf("failed to call Sync: %v", err) + return + } + + encResp := &mgmtProto.EncryptedMessage{} + if err = s.RecvMsg(encResp); err != nil { + t.Errorf("failed to receive sync response: %v", err) + return + } + + resp := &mgmtProto.SyncResponse{} + if err = encryption.DecryptMessage(ts.serverPubKey, key, encResp.Body, resp); err != nil { + t.Errorf("failed to decrypt sync response: %v", err) + return + } + ipChan <- resp.GetPeerConfig().Address + }() + } + + wg.Wait() + close(ipChan) + + ipMap := make(map[string]bool) + for ip := range ipChan { + if ipMap[ip] { + t.Fatalf("found duplicate IP: %s", ip) + } + ipMap[ip] = true + } + + // Ensure we collected all peers + if len(ipMap) != initialPeers { + t.Fatalf("expected %d unique IPs, got %d", initialPeers, len(ipMap)) + } +} diff --git a/management/server/route_test.go b/management/server/route_test.go index 1c5c56f6069..40e0f41b029 100644 --- a/management/server/route_test.go +++ b/management/server/route_test.go @@ -13,12 +13,11 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/netbirdio/netbird/management/domain" + "github.com/netbirdio/netbird/management/server/activity" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" networkTypes "github.com/netbirdio/netbird/management/server/networks/types" - - "github.com/netbirdio/netbird/management/domain" - "github.com/netbirdio/netbird/management/server/activity" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" diff --git a/management/server/store/sql_store_test.go b/management/server/store/sql_store_test.go index 5928b45baa7..f8e9bf0c0df 100644 --- a/management/server/store/sql_store_test.go +++ b/management/server/store/sql_store_test.go @@ -34,40 +34,44 @@ import ( nbroute "github.com/netbirdio/netbird/route" ) -func TestSqlite_NewStore(t *testing.T) { +func runTestForAllEngines(t *testing.T, testDataFile string, f func(t *testing.T, store Store)) { + t.Helper() + for _, engine := range supportedEngines { + if os.Getenv("NETBIRD_STORE_ENGINE") != "" && os.Getenv("NETBIRD_STORE_ENGINE") != string(engine) { + continue + } + t.Setenv("NETBIRD_STORE_ENGINE", string(engine)) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), testDataFile, t.TempDir()) + t.Cleanup(cleanUp) + assert.NoError(t, err) + t.Run(string(engine), func(t *testing.T) { + f(t, store) + }) + os.Unsetenv("NETBIRD_STORE_ENGINE") + } +} + +func Test_NewStore(t *testing.T) { if runtime.GOOS == "windows" { t.Skip("The SQLite store is not properly supported by Windows yet") } - t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) - store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir()) - t.Cleanup(cleanUp) - assert.NoError(t, err) - - if len(store.GetAllAccounts(context.Background())) != 0 { - t.Errorf("expected to create a new empty Accounts map when creating a new FileStore") - } + runTestForAllEngines(t, "", func(t *testing.T, store Store) { + if store == nil { + t.Errorf("expected to create a new Store") + } + if len(store.GetAllAccounts(context.Background())) != 0 { + t.Errorf("expected to create a new empty Accounts map when creating a new FileStore") + } + }) } -func TestSqlite_SaveAccount_Large(t *testing.T) { +func Test_SaveAccount_Large(t *testing.T) { if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" { t.Skip("skip CI tests on darwin and windows") } - t.Run("SQLite", func(t *testing.T) { - t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) - store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir()) - t.Cleanup(cleanUp) - assert.NoError(t, err) - runLargeTest(t, store) - }) - - // create store outside to have a better time counter for the test - t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) - store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir()) - t.Cleanup(cleanUp) - assert.NoError(t, err) - t.Run("PostgreSQL", func(t *testing.T) { + runTestForAllEngines(t, "", func(t *testing.T, store Store) { runLargeTest(t, store) }) } @@ -212,77 +216,74 @@ func randomIPv4() net.IP { return net.IP(b) } -func TestSqlite_SaveAccount(t *testing.T) { +func Test_SaveAccount(t *testing.T) { if runtime.GOOS == "windows" { t.Skip("The SQLite store is not properly supported by Windows yet") } - t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) - store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir()) - t.Cleanup(cleanUp) - assert.NoError(t, err) - - account := newAccountWithId(context.Background(), "account_id", "testuser", "") - setupKey, _ := types.GenerateDefaultSetupKey() - account.SetupKeys[setupKey.Key] = setupKey - account.Peers["testpeer"] = &nbpeer.Peer{ - Key: "peerkey", - IP: net.IP{127, 0, 0, 1}, - Meta: nbpeer.PeerSystemMeta{}, - Name: "peer name", - Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, - } + runTestForAllEngines(t, "", func(t *testing.T, store Store) { + account := newAccountWithId(context.Background(), "account_id", "testuser", "") + setupKey, _ := types.GenerateDefaultSetupKey() + account.SetupKeys[setupKey.Key] = setupKey + account.Peers["testpeer"] = &nbpeer.Peer{ + Key: "peerkey", + IP: net.IP{127, 0, 0, 1}, + Meta: nbpeer.PeerSystemMeta{}, + Name: "peer name", + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, + } - err = store.SaveAccount(context.Background(), account) - require.NoError(t, err) + err := store.SaveAccount(context.Background(), account) + require.NoError(t, err) - account2 := newAccountWithId(context.Background(), "account_id2", "testuser2", "") - setupKey, _ = types.GenerateDefaultSetupKey() - account2.SetupKeys[setupKey.Key] = setupKey - account2.Peers["testpeer2"] = &nbpeer.Peer{ - Key: "peerkey2", - IP: net.IP{127, 0, 0, 2}, - Meta: nbpeer.PeerSystemMeta{}, - Name: "peer name 2", - Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, - } + account2 := newAccountWithId(context.Background(), "account_id2", "testuser2", "") + setupKey, _ = types.GenerateDefaultSetupKey() + account2.SetupKeys[setupKey.Key] = setupKey + account2.Peers["testpeer2"] = &nbpeer.Peer{ + Key: "peerkey2", + IP: net.IP{127, 0, 0, 2}, + Meta: nbpeer.PeerSystemMeta{}, + Name: "peer name 2", + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, + } - err = store.SaveAccount(context.Background(), account2) - require.NoError(t, err) + err = store.SaveAccount(context.Background(), account2) + require.NoError(t, err) - if len(store.GetAllAccounts(context.Background())) != 2 { - t.Errorf("expecting 2 Accounts to be stored after SaveAccount()") - } + if len(store.GetAllAccounts(context.Background())) != 2 { + t.Errorf("expecting 2 Accounts to be stored after SaveAccount()") + } - a, err := store.GetAccount(context.Background(), account.Id) - if a == nil { - t.Errorf("expecting Account to be stored after SaveAccount(): %v", err) - } + a, err := store.GetAccount(context.Background(), account.Id) + if a == nil { + t.Errorf("expecting Account to be stored after SaveAccount(): %v", err) + } - if a != nil && len(a.Policies) != 1 { - t.Errorf("expecting Account to have one policy stored after SaveAccount(), got %d", len(a.Policies)) - } + if a != nil && len(a.Policies) != 1 { + t.Errorf("expecting Account to have one policy stored after SaveAccount(), got %d", len(a.Policies)) + } - if a != nil && len(a.Policies[0].Rules) != 1 { - t.Errorf("expecting Account to have one policy rule stored after SaveAccount(), got %d", len(a.Policies[0].Rules)) - return - } + if a != nil && len(a.Policies[0].Rules) != 1 { + t.Errorf("expecting Account to have one policy rule stored after SaveAccount(), got %d", len(a.Policies[0].Rules)) + return + } - if a, err := store.GetAccountByPeerPubKey(context.Background(), "peerkey"); a == nil { - t.Errorf("expecting PeerKeyID2AccountID index updated after SaveAccount(): %v", err) - } + if a, err := store.GetAccountByPeerPubKey(context.Background(), "peerkey"); a == nil { + t.Errorf("expecting PeerKeyID2AccountID index updated after SaveAccount(): %v", err) + } - if a, err := store.GetAccountByUser(context.Background(), "testuser"); a == nil { - t.Errorf("expecting UserID2AccountID index updated after SaveAccount(): %v", err) - } + if a, err := store.GetAccountByUser(context.Background(), "testuser"); a == nil { + t.Errorf("expecting UserID2AccountID index updated after SaveAccount(): %v", err) + } - if a, err := store.GetAccountByPeerID(context.Background(), "testpeer"); a == nil { - t.Errorf("expecting PeerID2AccountID index updated after SaveAccount(): %v", err) - } + if a, err := store.GetAccountByPeerID(context.Background(), "testpeer"); a == nil { + t.Errorf("expecting PeerID2AccountID index updated after SaveAccount(): %v", err) + } - if a, err := store.GetAccountBySetupKey(context.Background(), setupKey.Key); a == nil { - t.Errorf("expecting SetupKeyID2AccountID index updated after SaveAccount(): %v", err) - } + if a, err := store.GetAccountBySetupKey(context.Background(), setupKey.Key); a == nil { + t.Errorf("expecting SetupKeyID2AccountID index updated after SaveAccount(): %v", err) + } + }) } func TestSqlite_DeleteAccount(t *testing.T) { @@ -399,77 +400,24 @@ func TestSqlite_DeleteAccount(t *testing.T) { } } -func TestSqlite_GetAccount(t *testing.T) { - if runtime.GOOS == "windows" { - t.Skip("The SQLite store is not properly supported by Windows yet") - } - - t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) - store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) - t.Cleanup(cleanUp) - assert.NoError(t, err) - - id := "bf1c8084-ba50-4ce7-9439-34653001fc3b" - - account, err := store.GetAccount(context.Background(), id) - require.NoError(t, err) - require.Equal(t, id, account.Id, "account id should match") - - _, err = store.GetAccount(context.Background(), "non-existing-account") - assert.Error(t, err) - parsedErr, ok := status.FromError(err) - require.True(t, ok) - require.Equal(t, status.NotFound, parsedErr.Type(), "should return not found error") -} - -func TestSqlite_SavePeer(t *testing.T) { +func Test_GetAccount(t *testing.T) { if runtime.GOOS == "windows" { t.Skip("The SQLite store is not properly supported by Windows yet") } - t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) - store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) - t.Cleanup(cleanUp) - assert.NoError(t, err) - - account, err := store.GetAccount(context.Background(), "bf1c8084-ba50-4ce7-9439-34653001fc3b") - require.NoError(t, err) - - // save status of non-existing peer - peer := &nbpeer.Peer{ - Key: "peerkey", - ID: "testpeer", - IP: net.IP{127, 0, 0, 1}, - Meta: nbpeer.PeerSystemMeta{Hostname: "testingpeer"}, - Name: "peer name", - Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, - } - ctx := context.Background() - err = store.SavePeer(ctx, account.Id, peer) - assert.Error(t, err) - parsedErr, ok := status.FromError(err) - require.True(t, ok) - require.Equal(t, status.NotFound, parsedErr.Type(), "should return not found error") - - // save new status of existing peer - account.Peers[peer.ID] = peer - - err = store.SaveAccount(context.Background(), account) - require.NoError(t, err) - - updatedPeer := peer.Copy() - updatedPeer.Status.Connected = false - updatedPeer.Meta.Hostname = "updatedpeer" + runTestForAllEngines(t, "../testdata/store.sql", func(t *testing.T, store Store) { + id := "bf1c8084-ba50-4ce7-9439-34653001fc3b" - err = store.SavePeer(ctx, account.Id, updatedPeer) - require.NoError(t, err) - - account, err = store.GetAccount(context.Background(), account.Id) - require.NoError(t, err) + account, err := store.GetAccount(context.Background(), id) + require.NoError(t, err) + require.Equal(t, id, account.Id, "account id should match") - actual := account.Peers[peer.ID] - assert.Equal(t, updatedPeer.Status, actual.Status) - assert.Equal(t, updatedPeer.Meta, actual.Meta) + _, err = store.GetAccount(context.Background(), "non-existing-account") + assert.Error(t, err) + parsedErr, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, status.NotFound, parsedErr.Type(), "should return not found error") + }) } func TestSqlite_SavePeerStatus(t *testing.T) { @@ -581,74 +529,65 @@ func TestSqlite_SavePeerLocation(t *testing.T) { require.Equal(t, status.NotFound, parsedErr.Type(), "should return not found error") } -func TestSqlite_TestGetAccountByPrivateDomain(t *testing.T) { +func Test_TestGetAccountByPrivateDomain(t *testing.T) { if runtime.GOOS == "windows" { t.Skip("The SQLite store is not properly supported by Windows yet") } - t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) - store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) - t.Cleanup(cleanUp) - assert.NoError(t, err) + runTestForAllEngines(t, "../testdata/store.sql", func(t *testing.T, store Store) { + existingDomain := "test.com" - existingDomain := "test.com" + account, err := store.GetAccountByPrivateDomain(context.Background(), existingDomain) + require.NoError(t, err, "should found account") + require.Equal(t, existingDomain, account.Domain, "domains should match") - account, err := store.GetAccountByPrivateDomain(context.Background(), existingDomain) - require.NoError(t, err, "should found account") - require.Equal(t, existingDomain, account.Domain, "domains should match") - - _, err = store.GetAccountByPrivateDomain(context.Background(), "missing-domain.com") - require.Error(t, err, "should return error on domain lookup") - parsedErr, ok := status.FromError(err) - require.True(t, ok) - require.Equal(t, status.NotFound, parsedErr.Type(), "should return not found error") + _, err = store.GetAccountByPrivateDomain(context.Background(), "missing-domain.com") + require.Error(t, err, "should return error on domain lookup") + parsedErr, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, status.NotFound, parsedErr.Type(), "should return not found error") + }) } -func TestSqlite_GetTokenIDByHashedToken(t *testing.T) { +func Test_GetTokenIDByHashedToken(t *testing.T) { if runtime.GOOS == "windows" { t.Skip("The SQLite store is not properly supported by Windows yet") } - t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) - store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) - t.Cleanup(cleanUp) - assert.NoError(t, err) - - hashed := "SoMeHaShEdToKeN" - id := "9dj38s35-63fb-11ec-90d6-0242ac120003" + runTestForAllEngines(t, "../testdata/store.sql", func(t *testing.T, store Store) { + hashed := "SoMeHaShEdToKeN" + id := "9dj38s35-63fb-11ec-90d6-0242ac120003" - token, err := store.GetTokenIDByHashedToken(context.Background(), hashed) - require.NoError(t, err) - require.Equal(t, id, token) + token, err := store.GetTokenIDByHashedToken(context.Background(), hashed) + require.NoError(t, err) + require.Equal(t, id, token) - _, err = store.GetTokenIDByHashedToken(context.Background(), "non-existing-hash") - require.Error(t, err) - parsedErr, ok := status.FromError(err) - require.True(t, ok) - require.Equal(t, status.NotFound, parsedErr.Type(), "should return not found error") + _, err = store.GetTokenIDByHashedToken(context.Background(), "non-existing-hash") + require.Error(t, err) + parsedErr, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, status.NotFound, parsedErr.Type(), "should return not found error") + }) } -func TestSqlite_GetUserByTokenID(t *testing.T) { +func Test_GetUserByTokenID(t *testing.T) { if runtime.GOOS == "windows" { t.Skip("The SQLite store is not properly supported by Windows yet") } - t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) - store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) - t.Cleanup(cleanUp) - assert.NoError(t, err) - - id := "9dj38s35-63fb-11ec-90d6-0242ac120003" + runTestForAllEngines(t, "../testdata/store.sql", func(t *testing.T, store Store) { + id := "9dj38s35-63fb-11ec-90d6-0242ac120003" - user, err := store.GetUserByTokenID(context.Background(), id) - require.NoError(t, err) - require.Equal(t, id, user.PATs[id].ID) + user, err := store.GetUserByTokenID(context.Background(), id) + require.NoError(t, err) + require.Equal(t, id, user.PATs[id].ID) - _, err = store.GetUserByTokenID(context.Background(), "non-existing-id") - require.Error(t, err) - parsedErr, ok := status.FromError(err) - require.True(t, ok) - require.Equal(t, status.NotFound, parsedErr.Type(), "should return not found error") + _, err = store.GetUserByTokenID(context.Background(), "non-existing-id") + require.Error(t, err) + parsedErr, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, status.NotFound, parsedErr.Type(), "should return not found error") + }) } func TestMigrate(t *testing.T) { @@ -2370,6 +2309,7 @@ func TestSqlStore_GetNetworkRouterByID(t *testing.T) { } func TestSqlStore_SaveNetworkRouter(t *testing.T) { + t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine)) store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) diff --git a/management/server/store/store.go b/management/server/store/store.go index 91ae93c7c34..c3959f88f98 100644 --- a/management/server/store/store.go +++ b/management/server/store/store.go @@ -4,16 +4,21 @@ import ( "context" "errors" "fmt" + "math/rand" "net" "net/netip" + "net/url" "os" "path" "path/filepath" "runtime" + "slices" "strings" "time" log "github.com/sirupsen/logrus" + "gorm.io/driver/mysql" + "gorm.io/driver/postgres" "gorm.io/driver/sqlite" "gorm.io/gorm" @@ -177,6 +182,8 @@ const ( mysqlDsnEnv = "NETBIRD_STORE_ENGINE_MYSQL_DSN" ) +var supportedEngines = []Engine{SqliteStoreEngine, PostgresStoreEngine, MysqlStoreEngine} + func getStoreEngineFromEnv() Engine { // NETBIRD_STORE_ENGINE supposed to be used in tests. Otherwise, rely on the config file. kind, ok := os.LookupEnv("NETBIRD_STORE_ENGINE") @@ -185,7 +192,7 @@ func getStoreEngineFromEnv() Engine { } value := Engine(strings.ToLower(kind)) - if value == SqliteStoreEngine || value == PostgresStoreEngine || value == MysqlStoreEngine { + if slices.Contains(supportedEngines, value) { return value } @@ -333,10 +340,14 @@ func NewTestStoreFromSQL(ctx context.Context, filename string, dataDir string) ( } func getSqlStoreEngine(ctx context.Context, store *SqlStore, kind Engine) (Store, func(), error) { + var cleanup func() if kind == PostgresStoreEngine { - cleanUp, err := testutil.CreatePostgresTestContainer() - if err != nil { - return nil, nil, err + if envDsn, ok := os.LookupEnv(postgresDsnEnv); !ok || envDsn == "" { + var err error + _, err = testutil.CreatePostgresTestContainer() + if err != nil { + return nil, nil, err + } } dsn, ok := os.LookupEnv(postgresDsnEnv) @@ -344,18 +355,29 @@ func getSqlStoreEngine(ctx context.Context, store *SqlStore, kind Engine) (Store return nil, nil, fmt.Errorf("%s is not set", postgresDsnEnv) } - store, err = NewPostgresqlStoreFromSqlStore(ctx, store, dsn, nil) + db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{}) if err != nil { - return nil, nil, err + return nil, nil, fmt.Errorf("failed to open postgres connection: %v", err) } - return store, cleanUp, nil + dsn, cleanup, err = createRandomDB(dsn, db, kind) + if err != nil { + return nil, cleanup, err + } + + store, err = NewPostgresqlStoreFromSqlStore(ctx, store, dsn, nil) + if err != nil { + return nil, cleanup, err + } } if kind == MysqlStoreEngine { - cleanUp, err := testutil.CreateMysqlTestContainer() - if err != nil { - return nil, nil, err + if envDsn, ok := os.LookupEnv(mysqlDsnEnv); !ok || envDsn == "" { + var err error + _, err = testutil.CreateMysqlTestContainer() + if err != nil { + return nil, nil, err + } } dsn, ok := os.LookupEnv(mysqlDsnEnv) @@ -363,21 +385,96 @@ func getSqlStoreEngine(ctx context.Context, store *SqlStore, kind Engine) (Store return nil, nil, fmt.Errorf("%s is not set", mysqlDsnEnv) } + db, err := gorm.Open(mysql.Open(dsn+"?charset=utf8&parseTime=True&loc=Local"), &gorm.Config{}) + if err != nil { + return nil, nil, fmt.Errorf("failed to open mysql connection: %v", err) + } + + dsn, cleanup, err = createRandomDB(dsn, db, kind) + if err != nil { + return nil, cleanup, err + } + store, err = NewMysqlStoreFromSqlStore(ctx, store, dsn, nil) if err != nil { return nil, nil, err } - return store, cleanUp, nil } closeConnection := func() { + cleanup() store.Close(ctx) } return store, closeConnection, nil } +func createRandomDB(dsn string, db *gorm.DB, engine Engine) (string, func(), error) { + dbName := fmt.Sprintf("test_db_%d", rand.Intn(1e9)) + + if err := db.Exec(fmt.Sprintf("CREATE DATABASE %s", dbName)).Error; err != nil { + return "", nil, fmt.Errorf("failed to create database: %v", err) + } + + u, err := url.Parse(dsn) + if err != nil { + return "", nil, fmt.Errorf("failed to parse DSN: %v", err) + } + + u.Path = dbName + + cleanup := func() { + switch engine { + case PostgresStoreEngine: + err = db.Exec(fmt.Sprintf("DROP DATABASE %s WITH (FORCE)", dbName)).Error + case MysqlStoreEngine: + // err = killMySQLConnections(dsn, dbName) + err = db.Exec(fmt.Sprintf("DROP DATABASE %s", dbName)).Error + } + if err != nil { + log.Errorf("failed to drop database %s: %v", dbName, err) + panic(err) + } + sqlDB, _ := db.DB() + _ = sqlDB.Close() + } + + return u.String(), cleanup, nil + +} + +func killMySQLConnections(dsn, targetDB string) error { + u, err := url.Parse(dsn) + if err != nil { + return fmt.Errorf("failed to parse DSN: %v", err) + } + + u.Path = "testing" + + ctrlDB, err := gorm.Open(mysql.Open(dsn), &gorm.Config{}) + if err != nil { + return fmt.Errorf("failed to open mysql connection: %v", err) + } + + var procs []struct { + ID int + } + query := "SELECT ID FROM INFORMATION_SCHEMA.PROCESSLIST WHERE DB = ?" + if err := ctrlDB.Raw(query, targetDB).Scan(&procs).Error; err != nil { + return fmt.Errorf("failed to get processes: %v", err) + } + + for _, p := range procs { + killStmt := fmt.Sprintf("KILL %d", p.ID) + if err := ctrlDB.Exec(killStmt).Error; err != nil { + return fmt.Errorf("failed to kill process %d: %v", p.ID, err) + } + } + + return nil +} + func loadSQL(db *gorm.DB, filepath string) error { sqlContent, err := os.ReadFile(filepath) if err != nil { diff --git a/management/server/testutil/store.go b/management/server/testutil/store.go index 16438cab8f0..8672efa7ff4 100644 --- a/management/server/testutil/store.go +++ b/management/server/testutil/store.go @@ -22,7 +22,7 @@ func CreateMysqlTestContainer() (func(), error) { myContainer, err := mysql.RunContainer(ctx, testcontainers.WithImage("mlsmaycon/warmed-mysql:8"), mysql.WithDatabase("testing"), - mysql.WithUsername("testing"), + mysql.WithUsername("root"), mysql.WithPassword("testing"), testcontainers.WithWaitStrategy( wait.ForLog("/usr/sbin/mysqld: ready for connections"). @@ -34,6 +34,7 @@ func CreateMysqlTestContainer() (func(), error) { } cleanup := func() { + os.Unsetenv("NETBIRD_STORE_ENGINE_MYSQL_DSN") timeoutCtx, cancelFunc := context.WithTimeout(ctx, 1*time.Second) defer cancelFunc() if err = myContainer.Terminate(timeoutCtx); err != nil { @@ -68,6 +69,7 @@ func CreatePostgresTestContainer() (func(), error) { } cleanup := func() { + os.Unsetenv("NETBIRD_STORE_ENGINE_POSTGRES_DSN") timeoutCtx, cancelFunc := context.WithTimeout(ctx, 1*time.Second) defer cancelFunc() if err = pgContainer.Terminate(timeoutCtx); err != nil {