Skip to content

Commit

Permalink
Adding interface methods to logical.Backend for parity (hashicorp#2242)
Browse files Browse the repository at this point in the history
  • Loading branch information
armon authored and jefferai committed Jan 7, 2017
1 parent 0148443 commit 745df0a
Show file tree
Hide file tree
Showing 12 changed files with 165 additions and 14 deletions.
14 changes: 14 additions & 0 deletions logical/framework/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@ type Backend struct {
// to the backend, if required.
Clean CleanupFunc

// Invalidate is called when a keys is modified if required
Invalidate InvalidateFunc

// AuthRenew is the callback to call when a RenewRequest for an
// authentication comes in. By default, renewal won't be allowed.
// See the built-in AuthRenew helpers in lease.go for common callbacks.
Expand All @@ -92,6 +95,9 @@ type WALRollbackFunc func(*logical.Request, string, interface{}) error
// CleanupFunc is the callback for backend unload.
type CleanupFunc func()

// InvalidateFunc is the callback for backend key invalidation.
type InvalidateFunc func(string)

func (b *Backend) HandleExistenceCheck(req *logical.Request) (checkFound bool, exists bool, err error) {
b.once.Do(b.init)

Expand Down Expand Up @@ -218,12 +224,20 @@ func (b *Backend) Setup(config *logical.BackendConfig) (logical.Backend, error)
return b, nil
}

// Cleanup is used to release resources and prepare to stop the backend
func (b *Backend) Cleanup() {
if b.Clean != nil {
b.Clean()
}
}

// InvalidateKey is used to clear caches and reset internal state on key changes
func (b *Backend) InvalidateKey(key string) {
if b.Invalidate != nil {
b.Invalidate(key)
}
}

// Logger can be used to get the logger. If no logger has been set,
// the logs will be discarded.
func (b *Backend) Logger() log.Logger {
Expand Down
7 changes: 7 additions & 0 deletions logical/logical.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,14 @@ type Backend interface {
// existence check function was found, the item exists or not.
HandleExistenceCheck(*Request) (bool, bool, error)

// Cleanup is invoked during an unmount of a backend to allow it to
// handle any cleanup like connection closing or releasing of file handles.
Cleanup()

// InvalidateKey may be invoked when an object is modified that belongs
// to the backend. The backend can use this to clear any caches or reset
// internal state as needed.
InvalidateKey(key string)
}

// BackendConfig is provided to the factory to initialize the backend
Expand Down
6 changes: 6 additions & 0 deletions logical/storage.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
package logical

import (
"errors"
"fmt"
"strings"

"github.com/hashicorp/vault/helper/jsonutil"
)

// ErrReadOnly is returned when a backend does not support
// writing. This can be caused by a read-only replica or secondary
// cluster operation.
var ErrReadOnly = errors.New("Cannot write to readonly storage")

// Storage is the way that logical backends are able read/write data.
type Storage interface {
List(prefix string) ([]string, error)
Expand Down
10 changes: 10 additions & 0 deletions logical/system_view.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,11 @@ type SystemView interface {
// Returns true if caching is disabled. If true, no caches should be used,
// despite known slowdowns.
CachingDisabled() bool

// IsPrimary checks if this is a primary Vault instance. This
// can be used to avoid writes on secondaries and to avoid doing
// lazy upgrades which may cause writes.
IsPrimary() bool
}

type StaticSystemView struct {
Expand All @@ -38,6 +43,7 @@ type StaticSystemView struct {
SudoPrivilegeVal bool
TaintedVal bool
CachingDisabledVal bool
Primary bool
}

func (d StaticSystemView) DefaultLeaseTTL() time.Duration {
Expand All @@ -59,3 +65,7 @@ func (d StaticSystemView) Tainted() bool {
func (d StaticSystemView) CachingDisabled() bool {
return d.CachingDisabledVal
}

func (d StaticSystemView) IsPrimary() bool {
return d.Primary
}
13 changes: 10 additions & 3 deletions vault/barrier_view.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@ import (
// BarrierView implements logical.Storage so it can be passed in as the
// durable storage mechanism for logical views.
type BarrierView struct {
barrier BarrierStorage
prefix string
barrier BarrierStorage
prefix string
readonly bool
}

// NewBarrierView takes an underlying security barrier and returns
Expand Down Expand Up @@ -68,6 +69,9 @@ func (v *BarrierView) Get(key string) (*logical.StorageEntry, error) {

// logical.Storage impl.
func (v *BarrierView) Put(entry *logical.StorageEntry) error {
if v.readonly {
return logical.ErrReadOnly
}
if err := v.sanityCheck(entry.Key); err != nil {
return err
}
Expand All @@ -80,6 +84,9 @@ func (v *BarrierView) Put(entry *logical.StorageEntry) error {

// logical.Storage impl.
func (v *BarrierView) Delete(key string) error {
if v.readonly {
return logical.ErrReadOnly
}
if err := v.sanityCheck(key); err != nil {
return err
}
Expand All @@ -89,7 +96,7 @@ func (v *BarrierView) Delete(key string) error {
// SubView constructs a nested sub-view using the given prefix
func (v *BarrierView) SubView(prefix string) *BarrierView {
sub := v.expandKey(prefix)
return &BarrierView{barrier: v.barrier, prefix: sub}
return &BarrierView{barrier: v.barrier, prefix: sub, readonly: v.readonly}
}

// expandKey is used to expand to the full key path with the prefix
Expand Down
32 changes: 32 additions & 0 deletions vault/barrier_view_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -282,3 +282,35 @@ func TestBarrierView_ClearView(t *testing.T) {
t.Fatalf("have keys: %#v", out)
}
}
func TestBarrierView_Readonly(t *testing.T) {
_, barrier, _ := mockBarrier(t)
view := NewBarrierView(barrier, "foo/")

// Add a key before enabling read-only
entry := &Entry{Key: "test", Value: []byte("test")}
if err := view.Put(entry.Logical()); err != nil {
t.Fatalf("err: %v", err)
}

// Enable read only mode
view.readonly = true

// Put should fail in readonly mode
if err := view.Put(entry.Logical()); err != logical.ErrReadOnly {
t.Fatalf("err: %v", err)
}

// Delete nested
if err := view.Delete("test"); err != logical.ErrReadOnly {
t.Fatalf("err: %v", err)
}

// Check the non-nested key
e, err := view.Get("test")
if err != nil {
t.Fatalf("err: %v", err)
}
if e == nil {
t.Fatalf("key test missing")
}
}
8 changes: 8 additions & 0 deletions vault/dynamic_system_view_ext.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
// +build vault

package vault

// IsPrimary checks if this is a primary Vault instance.
func (d dynamicSystemView) IsPrimary() bool {
return true
}
5 changes: 4 additions & 1 deletion vault/rollback_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ func mockRollback(t *testing.T) (*RollbackManager, *NoopBackend) {
mounts := new(MountTable)
router := NewRouter()

_, barrier, _ := mockBarrier(t)
view := NewBarrierView(barrier, "logical/")

mounts.Entries = []*MountEntry{
&MountEntry{
Path: "foo",
Expand All @@ -26,7 +29,7 @@ func mockRollback(t *testing.T) (*RollbackManager, *NoopBackend) {
if err != nil {
t.Fatal(err)
}
if err := router.Mount(backend, "foo", &MountEntry{UUID: meUUID}, nil); err != nil {
if err := router.Mount(backend, "foo", &MountEntry{UUID: meUUID}, view); err != nil {
t.Fatalf("err: %s", err)
}

Expand Down
41 changes: 36 additions & 5 deletions vault/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,18 @@ type Router struct {
l sync.RWMutex
root *radix.Tree
tokenStoreSalt *salt.Salt

// storagePrefix maps the prefix used for storage (ala the BarrierView)
// to the backend. This is used to map a key back into the backend that owns it.
// For example, logical/uuid1/foobar -> secrets/ (generic backend) + foobar
storagePrefix *radix.Tree
}

// NewRouter returns a new router
func NewRouter() *Router {
r := &Router{
root: radix.New(),
root: radix.New(),
storagePrefix: radix.New(),
}
return r
}
Expand Down Expand Up @@ -69,6 +75,7 @@ func (r *Router) Mount(backend logical.Backend, prefix string, mountEntry *Mount
loginPaths: pathsToRadix(paths.Unauthenticated),
}
r.root.Insert(prefix, re)
r.storagePrefix.Insert(storageView.prefix, re)

return nil
}
Expand All @@ -78,12 +85,19 @@ func (r *Router) Unmount(prefix string) error {
r.l.Lock()
defer r.l.Unlock()

// Call backend's Cleanup routine
re, ok := r.root.Get(prefix)
if ok {
re.(*routeEntry).backend.Cleanup()
// Fast-path out if the backend doesn't exist
raw, ok := r.root.Get(prefix)
if !ok {
return nil
}

// Call backend's Cleanup routine
re := raw.(*routeEntry)
re.backend.Cleanup()

// Purge from the radix trees
r.root.Delete(prefix)
r.storagePrefix.Delete(re.storageView.prefix)
return nil
}

Expand Down Expand Up @@ -182,6 +196,23 @@ func (r *Router) MatchingSystemView(path string) logical.SystemView {
return raw.(*routeEntry).backend.System()
}

// MatchingStoragePrefix returns the mount path matching and storage prefix
// matching the given path
func (r *Router) MatchingStoragePrefix(path string) (string, string, bool) {
r.l.RLock()
_, raw, ok := r.storagePrefix.LongestPrefix(path)
r.l.RUnlock()
if !ok {
return "", "", false
}

// Extract the mount path and storage prefix
re := raw.(*routeEntry)
mountPath := re.mountEntry.Path
prefix := re.storageView.prefix
return mountPath, prefix, true
}

// Route is used to route a given request
func (r *Router) Route(req *logical.Request) (*logical.Response, error) {
resp, _, _, err := r.routeCommon(req, false)
Expand Down
33 changes: 30 additions & 3 deletions vault/router_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,10 @@ func (n *NoopBackend) Cleanup() {
// noop
}

func (n *NoopBackend) InvalidateKey(string) {
// noop
}

func TestRouter_Mount(t *testing.T) {
r := NewRouter()
_, barrier, _ := mockBarrier(t)
Expand All @@ -67,7 +71,7 @@ func TestRouter_Mount(t *testing.T) {
t.Fatal(err)
}
n := &NoopBackend{}
err = r.Mount(n, "prod/aws/", &MountEntry{UUID: meUUID}, view)
err = r.Mount(n, "prod/aws/", &MountEntry{Path: "prod/aws/", UUID: meUUID}, view)
if err != nil {
t.Fatalf("err: %v", err)
}
Expand Down Expand Up @@ -97,6 +101,14 @@ func TestRouter_Mount(t *testing.T) {
t.Fatalf("bad: %s", v)
}

mount, prefix, ok := r.MatchingStoragePrefix("logical/foo")
if !ok {
t.Fatalf("missing storage prefix")
}
if mount != "prod/aws/" || prefix != "logical/" {
t.Fatalf("Bad: %v - %v", mount, prefix)
}

req := &logical.Request{
Path: "prod/aws/foo",
}
Expand Down Expand Up @@ -124,7 +136,7 @@ func TestRouter_Unmount(t *testing.T) {
t.Fatal(err)
}
n := &NoopBackend{}
err = r.Mount(n, "prod/aws/", &MountEntry{UUID: meUUID}, view)
err = r.Mount(n, "prod/aws/", &MountEntry{Path: "prod/aws/", UUID: meUUID}, view)
if err != nil {
t.Fatalf("err: %v", err)
}
Expand All @@ -141,6 +153,10 @@ func TestRouter_Unmount(t *testing.T) {
if !strings.Contains(err.Error(), "unsupported path") {
t.Fatalf("err: %v", err)
}

if _, _, ok := r.MatchingStoragePrefix("logical/foo"); ok {
t.Fatalf("should not have matching storage prefix")
}
}

func TestRouter_Remount(t *testing.T) {
Expand All @@ -153,11 +169,13 @@ func TestRouter_Remount(t *testing.T) {
t.Fatal(err)
}
n := &NoopBackend{}
err = r.Mount(n, "prod/aws/", &MountEntry{UUID: meUUID}, view)
me := &MountEntry{Path: "prod/aws/", UUID: meUUID}
err = r.Mount(n, "prod/aws/", me, view)
if err != nil {
t.Fatalf("err: %v", err)
}

me.Path = "stage/aws/"
err = r.Remount("prod/aws/", "stage/aws/")
if err != nil {
t.Fatalf("err: %v", err)
Expand Down Expand Up @@ -188,6 +206,15 @@ func TestRouter_Remount(t *testing.T) {
if len(n.Paths) != 1 || n.Paths[0] != "foo" {
t.Fatalf("bad: %v", n.Paths)
}

// Check the resolve from storage still works
mount, prefix, _ := r.MatchingStoragePrefix("logical/foobar")
if mount != "stage/aws/" {
t.Fatalf("bad mount: %s", mount)
}
if prefix != "logical/" {
t.Fatalf("Bad prefix: %s", prefix)
}
}

func TestRouter_RootPath(t *testing.T) {
Expand Down
4 changes: 4 additions & 0 deletions vault/testing.go
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,10 @@ func (n *rawHTTP) Cleanup() {
// noop
}

func (n *rawHTTP) InvalidateKey(string) {
// noop
}

func GenerateRandBytes(length int) ([]byte, error) {
if length < 0 {
return nil, fmt.Errorf("length must be >= 0")
Expand Down
6 changes: 4 additions & 2 deletions vault/token_store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -593,11 +593,13 @@ func TestTokenStore_Revoke(t *testing.T) {
}

func TestTokenStore_Revoke_Leases(t *testing.T) {
_, ts, _, _ := TestCoreWithTokenStore(t)
c, ts, _, _ := TestCoreWithTokenStore(t)

view := NewBarrierView(c.barrier, "noop/")

// Mount a noop backend
noop := &NoopBackend{}
ts.expiration.router.Mount(noop, "", &MountEntry{UUID: ""}, nil)
ts.expiration.router.Mount(noop, "", &MountEntry{UUID: ""}, view)

ent := &TokenEntry{Path: "test", Policies: []string{"dev", "ops"}}
if err := ts.create(ent); err != nil {
Expand Down

0 comments on commit 745df0a

Please sign in to comment.