Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 25 additions & 3 deletions agent/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,8 @@ type Config struct {

WindowConnectedTime time.Duration `yaml:"window-connected-time"`

Setup Setup `yaml:"-"`
Setup Setup `yaml:"-"`
Encryption Encryption `yaml:"-"`
}

// ConfigFileDoesNotExistError error is returned from Get method if configuration file is expected,
Expand Down Expand Up @@ -335,7 +336,7 @@ func get(args []string, cfg *Config, l *logrus.Entry) (string, error) { //nolint
return configFileF, err
}
l.Infof("Loading configuration file %s.", configFileF)
fileCfg, err := loadFromFile(configFileF)
fileCfg, err := loadFromFile(configFileF, &cfg.Encryption)
if err != nil {
return configFileF, err
}
Expand Down Expand Up @@ -365,6 +366,10 @@ func Application(cfg *Config) (*kingpin.Application, *string) {

configFileF := app.Flag("config-file", "Configuration file path [PMM_AGENT_CONFIG_FILE]").
Envar("PMM_AGENT_CONFIG_FILE").PlaceHolder("</path/to/pmm-agent.yaml>").String()
app.Flag("config-file-key-file", "Path to the key file used to encrypt/decrypt the configuration file").
Envar("PMM_AGENT_CONFIG_FILE_KEY_FILE").StringVar(&cfg.Encryption.KeyFile)
app.Flag("config-file-key-password", "Password for the key file (if required)").
Envar("PMM_AGENT_CONFIG_FILE_KEY_PASSWORD").StringVar(&cfg.Encryption.KeyFilePassword)

app.Flag("id", "ID of this pmm-agent [PMM_AGENT_ID]").
Envar("PMM_AGENT_ID").StringVar(&cfg.ID)
Expand Down Expand Up @@ -526,7 +531,7 @@ func Application(cfg *Config) (*kingpin.Application, *string) {
// As a special case, if file does not exist, it returns ConfigFileDoesNotExistError.
// Other errors are returned if file exists, but configuration can't be loaded due to permission problems,
// YAML parsing problems, etc.
func loadFromFile(path string) (*Config, error) {
func loadFromFile(path string, enc *Encryption) (*Config, error) {
if _, err := os.Stat(path); errors.Is(err, fs.ErrNotExist) {
return nil, ConfigFileDoesNotExistError(path)
}
Expand All @@ -535,6 +540,15 @@ func loadFromFile(path string) (*Config, error) {
if err != nil {
return nil, err
}

encryptionEnabled := enc != nil && len(enc.KeyFile) != 0 && len(b) != 0
if encryptionEnabled {
b, err = enc.Decrypt(b)
if err != nil {
return nil, err
}
}

cfg := &Config{}
if err = yaml.Unmarshal(b, cfg); err != nil { //nolint:musttag // false positive
return nil, err
Expand All @@ -556,6 +570,14 @@ func SaveToFile(path string, cfg *Config, comment string) error {
}
res = append(res, "---\n"...)
res = append(res, b...)
encryptionEnabled := cfg != nil && len(cfg.Encryption.KeyFile) > 0
if encryptionEnabled {
res, err = cfg.Encryption.Encrypt(res)
if err != nil {
return err
}
}

return os.WriteFile(path, res, 0o640) //nolint:gosec
}

Expand Down
8 changes: 4 additions & 4 deletions agent/config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,13 @@ func TestLoadFromFile(t *testing.T) {
name := writeConfig(t, &Config{ID: "agent-id"})
t.Cleanup(func() { removeConfig(t, name) })

cfg, err := loadFromFile(name)
cfg, err := loadFromFile(name, nil)
require.NoError(t, err)
assert.Equal(t, &Config{ID: "agent-id"}, cfg)
})

t.Run("NotExist", func(t *testing.T) {
cfg, err := loadFromFile("not-exist.yaml")
cfg, err := loadFromFile("not-exist.yaml", nil)
assert.Equal(t, ConfigFileDoesNotExistError("not-exist.yaml"), err)
assert.Nil(t, cfg)
})
Expand All @@ -68,7 +68,7 @@ func TestLoadFromFile(t *testing.T) {
require.NoError(t, os.Chmod(name, 0o000))
t.Cleanup(func() { removeConfig(t, name) })

cfg, err := loadFromFile(name)
cfg, err := loadFromFile(name, nil)
require.IsType(t, (*os.PathError)(nil), err)
assert.Equal(t, "open", err.(*os.PathError).Op) //nolint:errorlint
require.EqualError(t, err.(*os.PathError).Err, "permission denied") //nolint:errorlint
Expand All @@ -80,7 +80,7 @@ func TestLoadFromFile(t *testing.T) {
require.NoError(t, os.WriteFile(name, []byte(`not YAML`), 0o666)) //nolint:gosec
t.Cleanup(func() { removeConfig(t, name) })

cfg, err := loadFromFile(name)
cfg, err := loadFromFile(name, nil)
require.IsType(t, (*yaml.TypeError)(nil), err)
require.EqualError(t, err, "yaml: unmarshal errors:\n line 1: cannot unmarshal !!str `not YAML` into config.Config")
assert.Nil(t, cfg)
Expand Down
143 changes: 143 additions & 0 deletions agent/config/encryption.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
// Copyright (C) 2023 Percona LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package config

import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"encoding/pem"
"errors"
"fmt"
"io"
"os"

"github.com/youmark/pkcs8"
)

// Encryption handles encryption and decryption of data using hybrid RSA + AES-GCM scheme.
type Encryption struct {
KeyFile string
KeyFilePassword string
}

const (
gcmNonceSize = 12
aesKeySize = 32
)

// Encrypt encrypts the given plaintext.
func (enc Encryption) Encrypt(plain []byte) ([]byte, error) {
priv, err := enc.readKeyFile()
if err != nil {
return nil, fmt.Errorf("unable to get RSA key from KeyFile: %w", err)
}

aesKey := make([]byte, aesKeySize)
if _, err := io.ReadFull(rand.Reader, aesKey); err != nil {
return nil, fmt.Errorf("unable to generate AES key: %w", err)
}

block, err := aes.NewCipher(aesKey)
if err != nil {
return nil, fmt.Errorf("unable to init AES: %w", err)
}
gcm, err := cipher.NewGCMWithNonceSize(block, gcmNonceSize)
if err != nil {
return nil, fmt.Errorf("unable to init GCM: %w", err)
}
nonce := make([]byte, gcmNonceSize)
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
return nil, fmt.Errorf("unable to generate nonce: %w", err)
}

ciphertext := gcm.Seal(nil, nonce, plain, nil)

wrappedKey, err := rsa.EncryptOAEP(sha256.New(), rand.Reader, &priv.PublicKey, aesKey, nil)
if err != nil {
return nil, fmt.Errorf("unable to RSA-wrap AES key: %w", err)
}

out := make([]byte, 0, len(wrappedKey)+len(nonce)+len(ciphertext))
out = append(out, wrappedKey...)
out = append(out, nonce...)
out = append(out, ciphertext...)
return out, nil
}

// Decrypt decrypts the given ciphertext.
func (enc Encryption) Decrypt(in []byte) ([]byte, error) {
priv, err := enc.readKeyFile()
if err != nil {
return nil, fmt.Errorf("unable to get RSA key from KeyFile: %w", err)
}

k := priv.Size()
if len(in) < k+gcmNonceSize+1 {
return nil, errors.New("ciphertext too short")
}

wrappedKey := in[:k]
nonce := in[k : k+gcmNonceSize]
ciphertext := in[k+gcmNonceSize:]

aesKey, err := rsa.DecryptOAEP(sha256.New(), rand.Reader, priv, wrappedKey, nil)
if err != nil {
return nil, fmt.Errorf("unable to RSA-unwrap AES key: %w", err)
}
if len(aesKey) != aesKeySize {
return nil, fmt.Errorf("unexpected AES key length: %d", len(aesKey))
}

block, err := aes.NewCipher(aesKey)
if err != nil {
return nil, fmt.Errorf("unable to init AES: %w", err)
}
gcm, err := cipher.NewGCMWithNonceSize(block, gcmNonceSize)
if err != nil {
return nil, fmt.Errorf("unable to init GCM: %w", err)
}

plain, err := gcm.Open(nil, nonce, ciphertext, nil)
if err != nil {
return nil, fmt.Errorf("unable to decrypt (wrong key or data tampered): %w", err)
}
return plain, nil
}

func (enc Encryption) readKeyFile() (*rsa.PrivateKey, error) {
f, err := os.ReadFile(enc.KeyFile)
if err != nil {
return nil, fmt.Errorf("unable to read KeyFile: %w", err)
}

block, _ := pem.Decode(f)
if block == nil {
return nil, errors.New("no valid private key found in a KeyFile")
}

k, err := pkcs8.ParsePKCS8PrivateKey(block.Bytes, []byte(enc.KeyFilePassword))
if err != nil {
return nil, fmt.Errorf("unable to parse private key: %w", err)
}

rsaKey, ok := k.(*rsa.PrivateKey)
if !ok {
return nil, errors.New("private key is not RSA")
}
return rsaKey, nil
}
133 changes: 133 additions & 0 deletions agent/config/encryption_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
// Copyright (C) 2023 Percona LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package config

import (
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"os"
"path/filepath"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/youmark/pkcs8"
)

func generateRSAKey(t *testing.T) []byte {
t.Helper()

privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)

privateKeyBytes, err := x509.MarshalPKCS8PrivateKey(privateKey)
require.NoError(t, err)

privateKeyPEM := pem.EncodeToMemory(&pem.Block{
Type: "PRIVATE KEY",
Bytes: privateKeyBytes,
})

return privateKeyPEM
}

func generateEncryptedRSAKey(t *testing.T, password string) []byte {
t.Helper()

privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)

privateKeyDER, err := pkcs8.MarshalPrivateKey(privateKey, []byte(password), nil)
require.NoError(t, err)

privateKeyPEM := pem.EncodeToMemory(&pem.Block{
Type: "ENCRYPTED PRIVATE KEY",
Bytes: privateKeyDER,
})

return privateKeyPEM
}

func writeKey(t *testing.T, keyname string, key []byte) string {
t.Helper()
path := filepath.Join(t.TempDir(), keyname)
require.NoError(t, os.WriteFile(path, key, 0o600))
return path
}

func TestEncryption(t *testing.T) {
t.Run("Encrypted", func(t *testing.T) {
keyPEM := generateRSAKey(t)
key := writeKey(t, "key", keyPEM)
enc := Encryption{
KeyFile: key,
}
configfilef := writeConfig(t, &Config{ID: "agent-id", Encryption: enc})
cfg, err := loadFromFile(configfilef, &enc)
require.NoError(t, err)
assert.Equal(t, &Config{ID: "agent-id"}, cfg)
})

t.Run("EncryptedPassword", func(t *testing.T) {
password := "abcdefgh"
keyPEM := generateEncryptedRSAKey(t, password)
key := writeKey(t, "key", keyPEM)
enc := Encryption{
KeyFile: key,
KeyFilePassword: password,
}
configfilef := writeConfig(t, &Config{ID: "agent-id", Encryption: enc})
cfg, err := loadFromFile(configfilef, &enc)
require.NoError(t, err)
assert.Equal(t, &Config{ID: "agent-id"}, cfg)
})

t.Run("EncryptedWrongPassword", func(t *testing.T) {
password := "abcdefgh"
keyPEM := generateEncryptedRSAKey(t, password)
key := writeKey(t, "key", keyPEM)
configfilef := writeConfig(t, &Config{ID: "agent-id", Encryption: Encryption{
KeyFile: key,
KeyFilePassword: password,
}})

cfg, err := loadFromFile(configfilef, &Encryption{
KeyFile: key,
KeyFilePassword: "hgfedcba",
})
require.EqualError(t, err, "unable to get RSA key from KeyFile: unable to parse private key: pkcs8: incorrect password")
assert.Nil(t, cfg)
})

t.Run("EncryptedWrongKey", func(t *testing.T) {
password := "abcdefgh"
key1PEM := generateEncryptedRSAKey(t, password)
key2PEM := generateRSAKey(t)
key1 := writeKey(t, "key1", key1PEM)
key2 := writeKey(t, "key2", key2PEM)

configfilef := writeConfig(t, &Config{ID: "agent-id", Encryption: Encryption{
KeyFile: key2,
}})
cfg, err := loadFromFile(configfilef, &Encryption{
KeyFile: key1,
KeyFilePassword: password,
})
require.EqualError(t, err, "unable to RSA-unwrap AES key: crypto/rsa: decryption error")
assert.Nil(t, cfg)
})
}
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ require (
github.com/xdg-go/pbkdf2 v1.0.0 // indirect
github.com/xdg-go/scram v1.2.0 // indirect
github.com/xdg-go/stringprep v1.0.4 // indirect
github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 // indirect
github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78
go.opentelemetry.io/otel v1.39.0 // indirect
go.opentelemetry.io/otel/trace v1.39.0 // indirect
golang.org/x/mod v0.31.0 // indirect
Expand Down
Loading