Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Group based query protection #1337

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 86 additions & 16 deletions lighthouse.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,10 @@ type LightHouse struct {

calculatedRemotes atomic.Pointer[bart.Table[[]*calculatedRemote]] // Maps VpnAddr to []*calculatedRemote

metrics *MessageMetrics
metricHolepunchTx metrics.Counter
l *logrus.Logger
metrics *MessageMetrics
metricHolepunchTx metrics.Counter
l *logrus.Logger
queryProtectionTable lightHouseQueryProtectionTable
}

// NewLightHouseFromConfig will build a Lighthouse struct from the values provided in the config object
Expand All @@ -95,17 +96,21 @@ func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C,
nebulaPort = uint32(uPort.Port())
}

queryProtection := newQueryProtectionTableFromConfig(c)
l.Infof("Loaded rules %v", queryProtection)

h := LightHouse{
ctx: ctx,
amLighthouse: amLighthouse,
myVpnNetworks: cs.myVpnNetworks,
myVpnNetworksTable: cs.myVpnNetworksTable,
addrMap: make(map[netip.Addr]*RemoteList),
nebulaPort: nebulaPort,
punchConn: pc,
punchy: p,
queryChan: make(chan netip.Addr, c.GetUint32("handshakes.query_buffer", 64)),
l: l,
ctx: ctx,
amLighthouse: amLighthouse,
myVpnNetworks: cs.myVpnNetworks,
myVpnNetworksTable: cs.myVpnNetworksTable,
addrMap: make(map[netip.Addr]*RemoteList),
nebulaPort: nebulaPort,
punchConn: pc,
punchy: p,
queryChan: make(chan netip.Addr, c.GetUint32("handshakes.query_buffer", 64)),
l: l,
queryProtectionTable: queryProtection,
}
lighthouses := make(map[netip.Addr]struct{})
h.lighthouses.Store(&lighthouses)
Expand Down Expand Up @@ -1010,7 +1015,10 @@ func (lhh *LightHouseHandler) resetMeta() *NebulaMeta {
return lhh.meta
}

func (lhh *LightHouseHandler) HandleRequest(rAddr netip.AddrPort, fromVpnAddrs []netip.Addr, p []byte, w EncWriter) {
func (lhh *LightHouseHandler) HandleRequest(rAddr netip.AddrPort, hostInfo *HostInfo, p []byte, w EncWriter) {

fromVpnAddrs := hostInfo.vpnAddrs

n := lhh.resetMeta()
err := n.Unmarshal(p)
if err != nil {
Expand All @@ -1029,7 +1037,7 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr netip.AddrPort, fromVpnAddrs [

switch n.Type {
case NebulaMeta_HostQuery:
lhh.handleHostQuery(n, fromVpnAddrs, rAddr, w)
lhh.handleHostQuery(n, hostInfo, rAddr, w)

case NebulaMeta_HostQueryReply:
lhh.handleHostQueryReply(n, fromVpnAddrs)
Expand All @@ -1046,7 +1054,7 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr netip.AddrPort, fromVpnAddrs [
}
}

func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []netip.Addr, addr netip.AddrPort, w EncWriter) {
func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, hostInfo *HostInfo, addr netip.AddrPort, w EncWriter) {
// Exit if we don't answer queries
if !lhh.lh.amLighthouse {
if lhh.l.Level >= logrus.DebugLevel {
Expand All @@ -1055,6 +1063,9 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []neti
return
}

// Get the from addrs back
fromVpnAddrs := hostInfo.vpnAddrs

useVersion := cert.Version1
var queryVpnAddr netip.Addr
if n.Details.OldVpnAddr != 0 {
Expand All @@ -1072,6 +1083,13 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []neti
return
}

queriedHostInvertedGroups := lhh.lh.ifce.GetHostInfo(queryVpnAddr).ConnectionState.peerCert.InvertedGroups

if !lhh.lh.queryProtectionTable.check(hostInfo.ConnectionState.peerCert.InvertedGroups, queriedHostInvertedGroups, lhh.l) {
lhh.l.Debugln("Dropping HostQuery from", fromVpnAddrs, "for", queryVpnAddr, "due to query protection")
return
}

found, ln, err := lhh.lh.queryAndPrepMessage(queryVpnAddr, func(c *cache) (int, error) {
n = lhh.resetMeta()
n.Type = NebulaMeta_HostQueryReply
Expand Down Expand Up @@ -1448,3 +1466,55 @@ func findNetworkUnion(prefixes []netip.Prefix, addrs []netip.Addr) (netip.Addr,
}
return netip.Addr{}, false
}

type lightHouseQueryProtectionTable interface {
check(invertedGroups map[string]struct{}, queriedHostInvertedGroups map[string]struct{}, logger *logrus.Logger) bool
}
type QueryProtectionTable struct {
rules map[string][]string
}

func newQueryProtectionTableFromConfig(c *config.C) *QueryProtectionTable {
rawRules := c.GetMap("lighthouse.query_protection", map[interface{}]interface{}{})

rules := make(map[string][]string)
for k, v := range rawRules {
var subgroups []string
for _, s := range v.([]interface{}) {
subgroups = append(subgroups, s.(string))
}
rules[k.(string)] = subgroups
}

return &QueryProtectionTable{
rules: rules,
}
}

func (l *QueryProtectionTable) check(invertedGroups map[string]struct{}, queriedHostInvertedGroups map[string]struct{}, logger *logrus.Logger) bool {

if len(l.rules) == 0 {
return true
}

for group := range invertedGroups {
if allowedGroups, ok := l.rules[group]; ok {
for _, allowedGroup := range allowedGroups {
if _, ok := queriedHostInvertedGroups[allowedGroup]; ok {
return true
}
}
}
}

logger.Debugf("Dropped query due to query protection: %s : %s", invertedGroups, queriedHostInvertedGroups)
return false
}

type mockQueryProtectionTable struct {
}

func (m *mockQueryProtectionTable) check(invertedGroups map[string]struct{}, queriedHostInvertedGroups map[string]struct{}, logger *logrus.Logger) bool {
// It's all good, man
return true
}
66 changes: 63 additions & 3 deletions lighthouse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,9 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {

mw := &mockEncWriter{}

hi := []netip.Addr{vpnIp2}
hi := &HostInfo{
vpnAddrs: []netip.Addr{vpnIp2},
}
b.Run("notfound", func(b *testing.B) {
lhh := lh.NewRequestHandler()
req := &NebulaMeta{
Expand Down Expand Up @@ -205,6 +207,7 @@ func TestLighthouse_Memory(t *testing.T) {
}
lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
lh.ifce = &mockEncWriter{}
lh.queryProtectionTable = &mockQueryProtectionTable{}
assert.NoError(t, err)
lhh := lh.NewRequestHandler()

Expand Down Expand Up @@ -327,7 +330,13 @@ func newLHHostRequest(fromAddr netip.AddrPort, myVpnIp, queryVpnIp netip.Addr, l
w := &testEncWriter{
metaFilter: &filter,
}
lhh.HandleRequest(fromAddr, []netip.Addr{myVpnIp}, b, w)
hi := &HostInfo{
vpnAddrs: []netip.Addr{myVpnIp},
ConnectionState: &ConnectionState{
peerCert: &cert.CachedCertificate{InvertedGroups: map[string]struct{}{}},
},
}
lhh.HandleRequest(fromAddr, hi, b, w)
return w.lastReply
}

Expand Down Expand Up @@ -357,8 +366,12 @@ func newLHHostUpdate(fromAddr netip.AddrPort, vpnIp netip.Addr, addrs []netip.Ad
panic(err)
}

hi := &HostInfo{
vpnAddrs: []netip.Addr{vpnIp},
}

w := &testEncWriter{}
lhh.HandleRequest(fromAddr, []netip.Addr{vpnIp}, b, w)
lhh.HandleRequest(fromAddr, hi, b, w)
}

type testLhReply struct {
Expand Down Expand Up @@ -494,3 +507,50 @@ func Test_findNetworkUnion(t *testing.T) {
out, ok = findNetworkUnion([]netip.Prefix{fc00}, []netip.Addr{a1, afe81})
assert.False(t, ok)
}

func TestQueryProtectionTable(t *testing.T) {
l := test.NewLogger()

qpt := QueryProtectionTable{
rules: map[string][]string{
"group1": {"allowed1", "allowed2"},
"group2": {"allowed3"},
},
}

invertedGroups := map[string]struct{}{
"group1": {},
}
queriedHostInvertedGroups := map[string]struct{}{
"allowed1": {},
}

assert.True(t, qpt.check(invertedGroups, queriedHostInvertedGroups, l))

queriedHostInvertedGroups = map[string]struct{}{
"notAllowed": {},
}

assert.False(t, qpt.check(invertedGroups, queriedHostInvertedGroups, l))

invertedGroups = map[string]struct{}{
"group2": {},
}
queriedHostInvertedGroups = map[string]struct{}{
"allowed3": {},
}

assert.True(t, qpt.check(invertedGroups, queriedHostInvertedGroups, l))

invertedGroups = map[string]struct{}{
"group3": {},
}
queriedHostInvertedGroups = map[string]struct{}{
"allowed1": {},
}

assert.False(t, qpt.check(invertedGroups, queriedHostInvertedGroups, l))

qpt.rules = map[string][]string{}
assert.True(t, qpt.check(invertedGroups, queriedHostInvertedGroups, l))
}
2 changes: 1 addition & 1 deletion outside.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
return
}

lhf.HandleRequest(ip, hostinfo.vpnAddrs, d, f)
lhf.HandleRequest(ip, hostinfo, d, f)

// Fallthrough to the bottom to record incoming traffic

Expand Down
Loading