From 969439d828d43c6b02d1e261e521ad061a31eb9d Mon Sep 17 00:00:00 2001 From: Shivam Sandbhor Date: Tue, 5 Oct 2021 11:03:34 +0530 Subject: [PATCH 1/3] Fix IPv6 and timeout bug in nftables. Signed-off-by: Shivam Sandbhor --- nftables.go | 44 +++++++++++++++++++++++--------------------- 1 file changed, 23 insertions(+), 21 deletions(-) diff --git a/nftables.go b/nftables.go index c8722f3d..99c61b31 100644 --- a/nftables.go +++ b/nftables.go @@ -10,22 +10,22 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/models" "github.com/google/nftables" "github.com/google/nftables/expr" - "golang.org/x/sys/unix" log "github.com/sirupsen/logrus" + "golang.org/x/sys/unix" ) const defaultTimeout = 4 * time.Hour type nft struct { - conn *nftables.Conn - conn6 *nftables.Conn - set *nftables.Set - set6 *nftables.Set - table *nftables.Table - table6 *nftables.Table - DenyAction string - DenyLog bool - DenyLogPrefix string + conn *nftables.Conn + conn6 *nftables.Conn + set *nftables.Set + set6 *nftables.Set + table *nftables.Table + table6 *nftables.Table + DenyAction string + DenyLog bool + DenyLogPrefix string BlacklistsIpv4 string BlacklistsIpv6 string } @@ -61,9 +61,10 @@ func (n *nft) Init() error { Priority: nftables.ChainPriorityFilter, }) set := &nftables.Set{ - Name: n.BlacklistsIpv4, - Table: n.table, - KeyType: nftables.TypeIPAddr, + Name: n.BlacklistsIpv4, + Table: n.table, + KeyType: nftables.TypeIPAddr, + HasTimeout: true, } if err := n.conn.AddSet(set, []nftables.SetElement{}); err != nil { @@ -129,9 +130,10 @@ func (n *nft) Init() error { Priority: nftables.ChainPriorityFilter, }) set := &nftables.Set{ - Name: n.BlacklistsIpv6, - Table: n.table6, - KeyType: nftables.TypeIP6Addr, + Name: n.BlacklistsIpv6, + Table: n.table6, + KeyType: nftables.TypeIP6Addr, + HasTimeout: true, } if err := n.conn6.AddSet(set, []nftables.SetElement{}); err != nil { @@ -194,7 +196,7 @@ func (n *nft) Add(decision *models.Decision) error { } if strings.Contains(*decision.Value, ":") { // ipv6 if n.conn6 != nil { - if err := n.conn.SetAddElements(n.set6, []nftables.SetElement{{Key: []byte(net.ParseIP(*decision.Value).To16()), Timeout: timeout}}); err != nil { + if err := n.conn6.SetAddElements(n.set6, []nftables.SetElement{{Key: []byte(net.ParseIP(*decision.Value).To16()), Timeout: timeout}}); err != nil { return err } if err := n.conn6.Flush(); err != nil { @@ -211,7 +213,7 @@ func (n *nft) Add(decision *models.Decision) error { } else { ipAddr = *decision.Value } - if err := n.conn.SetAddElements(n.set, []nftables.SetElement{{Key: []byte(net.ParseIP(ipAddr).To4())}}); err != nil { + if err := n.conn.SetAddElements(n.set, []nftables.SetElement{{Key: []byte(net.ParseIP(ipAddr).To4()), Timeout: timeout}}); err != nil { return err } if err := n.conn.Flush(); err != nil { @@ -225,14 +227,14 @@ func (n *nft) Add(decision *models.Decision) error { func (n *nft) Delete(decision *models.Decision) error { if strings.Contains(*decision.Value, ":") { // ipv6 if n.conn6 != nil { - if err := n.conn.SetDeleteElements(n.set, []nftables.SetElement{{Key: net.ParseIP(*decision.Value).To16()}}); err != nil { + if err := n.conn6.SetDeleteElements(n.set6, []nftables.SetElement{{Key: []byte(net.ParseIP(*decision.Value).To16())}}); err != nil { return err } - if err := n.conn.Flush(); err != nil { + if err := n.conn6.Flush(); err != nil { return err } } else { - log.Debugf("not adding '%s' because ipv6 is disabled", *decision.Value) + log.Debugf("not removing '%s' because ipv6 is disabled", *decision.Value) return nil } } else { // ipv4 From 93bd0ea499c7d0efa40f428098dc79b989f87f8f Mon Sep 17 00:00:00 2001 From: Shivam Sandbhor Date: Wed, 6 Oct 2021 15:40:46 +0530 Subject: [PATCH 2/3] Support ranges in nftables Signed-off-by: Shivam Sandbhor --- nftables.go | 112 ++++++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 100 insertions(+), 12 deletions(-) diff --git a/nftables.go b/nftables.go index 99c61b31..b4be88bf 100644 --- a/nftables.go +++ b/nftables.go @@ -3,6 +3,7 @@ package main import ( + "fmt" "net" "strings" "time" @@ -65,6 +66,7 @@ func (n *nft) Init() error { Table: n.table, KeyType: nftables.TypeIPAddr, HasTimeout: true, + Interval: true, } if err := n.conn.AddSet(set, []nftables.SetElement{}); err != nil { @@ -134,6 +136,7 @@ func (n *nft) Init() error { Table: n.table6, KeyType: nftables.TypeIP6Addr, HasTimeout: true, + Interval: true, } if err := n.conn6.AddSet(set, []nftables.SetElement{}); err != nil { @@ -194,9 +197,23 @@ func (n *nft) Add(decision *models.Decision) error { log.Errorf("unable to parse timeout '%s' for '%s' : %s", *decision.Duration, *decision.Value, err) timeout = defaultTimeout } + var cidr string if strings.Contains(*decision.Value, ":") { // ipv6 if n.conn6 != nil { - if err := n.conn6.SetAddElements(n.set6, []nftables.SetElement{{Key: []byte(net.ParseIP(*decision.Value).To16()), Timeout: timeout}}); err != nil { + if !strings.Contains(*decision.Value, "/") { + cidr = fmt.Sprintf("%s/128", *decision.Value) + } else { + cidr = *decision.Value + } + _, cidrNet, err := net.ParseCIDR(cidr) + if err != nil { + return err + } + if err := n.conn6.SetAddElements(n.set6, + []nftables.SetElement{ + {Key: []byte(cidrNet.IP.To16()), Timeout: timeout}, + {Key: []byte(incrementIP(BroadcastAddr(cidrNet)).To16()), IntervalEnd: true}, + }); err != nil { return err } if err := n.conn6.Flush(); err != nil { @@ -207,13 +224,20 @@ func (n *nft) Add(decision *models.Decision) error { return nil } } else { // ipv4 - var ipAddr string - if strings.Contains(*decision.Value, "/") { - ipAddr = strings.Split(*decision.Value, "/")[0] + if !strings.Contains(*decision.Value, "/") { + cidr = fmt.Sprintf("%s/32", *decision.Value) } else { - ipAddr = *decision.Value + cidr = *decision.Value } - if err := n.conn.SetAddElements(n.set, []nftables.SetElement{{Key: []byte(net.ParseIP(ipAddr).To4()), Timeout: timeout}}); err != nil { + _, cidrNet, err := net.ParseCIDR(cidr) + if err != nil { + return err + } + if err := n.conn.SetAddElements(n.set, + []nftables.SetElement{ + {Key: cidrNet.IP, Timeout: timeout}, + {Key: incrementIP(BroadcastAddr(cidrNet)), IntervalEnd: true}, + }); err != nil { return err } if err := n.conn.Flush(); err != nil { @@ -225,26 +249,49 @@ func (n *nft) Add(decision *models.Decision) error { } func (n *nft) Delete(decision *models.Decision) error { + var cidr string if strings.Contains(*decision.Value, ":") { // ipv6 if n.conn6 != nil { - if err := n.conn6.SetDeleteElements(n.set6, []nftables.SetElement{{Key: []byte(net.ParseIP(*decision.Value).To16())}}); err != nil { + if !strings.Contains(*decision.Value, "/") { + cidr = fmt.Sprintf("%s/128", *decision.Value) + } else { + cidr = *decision.Value + } + _, cidrNet, err := net.ParseCIDR(cidr) + if err != nil { + return err + } + if err := n.conn6.SetDeleteElements(n.set6, + []nftables.SetElement{ + {Key: []byte(cidrNet.IP.To16())}, + {Key: []byte(incrementIP(BroadcastAddr(cidrNet)).To16()), IntervalEnd: true}, + }); err != nil { return err } if err := n.conn6.Flush(); err != nil { return err } + } else { log.Debugf("not removing '%s' because ipv6 is disabled", *decision.Value) return nil } } else { // ipv4 - var ipAddr string - if strings.Contains(*decision.Value, "/") { - ipAddr = strings.Split(*decision.Value, "/")[0] + var cidr string + if !strings.Contains(*decision.Value, "/") { + cidr = fmt.Sprintf("%s/32", *decision.Value) } else { - ipAddr = *decision.Value + cidr = *decision.Value + } + _, cidrNet, err := net.ParseCIDR(cidr) + if err != nil { + return err } - if err := n.conn.SetDeleteElements(n.set, []nftables.SetElement{{Key: net.ParseIP(ipAddr).To4()}}); err != nil { + if err := n.conn.SetDeleteElements(n.set, + []nftables.SetElement{ + {Key: cidrNet.IP}, + {Key: incrementIP(BroadcastAddr(cidrNet)), IntervalEnd: true}, + }); err != nil { return err } if err := n.conn.Flush(); err != nil { @@ -271,3 +318,44 @@ func (n *nft) ShutDown() error { } return nil } + +// Utilites from https://github.com/IBM/netaddr/blob/master/net_utils.go + +// NewIP returns a new IP with the given size. The size must be 4 for IPv4 and +// 16 for IPv6. +func NewIP(size int) net.IP { + if size == 4 { + return net.ParseIP("0.0.0.0").To4() + } + if size == 16 { + return net.ParseIP("::") + } + panic("Bad value for size") +} + +// BroadcastAddr returns the last address in the given network, or the broadcast address. +func BroadcastAddr(n *net.IPNet) net.IP { + // The golang net package doesn't make it easy to calculate the broadcast address. :( + broadcast := NewIP(len(n.IP)) + for i := 0; i < len(n.IP); i++ { + broadcast[i] = n.IP[i] | ^n.Mask[i] + } + return broadcast +} + +// incrementIP returns the given IP + 1 +func incrementIP(ip net.IP) (result net.IP) { + result = make([]byte, len(ip)) // start off with a nice empty ip of proper length + + carry := true + for i := len(ip) - 1; i >= 0; i-- { + result[i] = ip[i] + if carry { + result[i]++ + if result[i] != 0 { + carry = false + } + } + } + return +} From cd21e91a45708778a8e0989bacda9169b680eb4d Mon Sep 17 00:00:00 2001 From: Shivam Sandbhor Date: Wed, 6 Oct 2021 15:55:06 +0530 Subject: [PATCH 3/3] Error out when NewIP is given bad size Signed-off-by: Shivam Sandbhor --- nftables.go | 41 ++++++++++++++++++++++++++++++----------- 1 file changed, 30 insertions(+), 11 deletions(-) diff --git a/nftables.go b/nftables.go index b4be88bf..a261b8e5 100644 --- a/nftables.go +++ b/nftables.go @@ -209,10 +209,14 @@ func (n *nft) Add(decision *models.Decision) error { if err != nil { return err } + bca, err := BroadcastAddr(cidrNet) + if err != nil { + return err + } if err := n.conn6.SetAddElements(n.set6, []nftables.SetElement{ {Key: []byte(cidrNet.IP.To16()), Timeout: timeout}, - {Key: []byte(incrementIP(BroadcastAddr(cidrNet)).To16()), IntervalEnd: true}, + {Key: []byte(incrementIP(bca).To16()), IntervalEnd: true}, }); err != nil { return err } @@ -233,10 +237,14 @@ func (n *nft) Add(decision *models.Decision) error { if err != nil { return err } + bca, err := BroadcastAddr(cidrNet) + if err != nil { + return err + } if err := n.conn.SetAddElements(n.set, []nftables.SetElement{ {Key: cidrNet.IP, Timeout: timeout}, - {Key: incrementIP(BroadcastAddr(cidrNet)), IntervalEnd: true}, + {Key: incrementIP(bca), IntervalEnd: true}, }); err != nil { return err } @@ -261,10 +269,14 @@ func (n *nft) Delete(decision *models.Decision) error { if err != nil { return err } + bca, err := BroadcastAddr(cidrNet) + if err != nil { + return err + } if err := n.conn6.SetDeleteElements(n.set6, []nftables.SetElement{ {Key: []byte(cidrNet.IP.To16())}, - {Key: []byte(incrementIP(BroadcastAddr(cidrNet)).To16()), IntervalEnd: true}, + {Key: []byte(incrementIP(bca).To16()), IntervalEnd: true}, }); err != nil { return err } @@ -287,10 +299,14 @@ func (n *nft) Delete(decision *models.Decision) error { if err != nil { return err } + bca, err := BroadcastAddr(cidrNet) + if err != nil { + return err + } if err := n.conn.SetDeleteElements(n.set, []nftables.SetElement{ {Key: cidrNet.IP}, - {Key: incrementIP(BroadcastAddr(cidrNet)), IntervalEnd: true}, + {Key: incrementIP(bca), IntervalEnd: true}, }); err != nil { return err } @@ -323,24 +339,27 @@ func (n *nft) ShutDown() error { // NewIP returns a new IP with the given size. The size must be 4 for IPv4 and // 16 for IPv6. -func NewIP(size int) net.IP { +func NewIP(size int) (net.IP, error) { if size == 4 { - return net.ParseIP("0.0.0.0").To4() + return net.ParseIP("0.0.0.0").To4(), nil } if size == 16 { - return net.ParseIP("::") + return net.ParseIP("::"), nil } - panic("Bad value for size") + return net.IP{}, fmt.Errorf("invalid size %d", size) } // BroadcastAddr returns the last address in the given network, or the broadcast address. -func BroadcastAddr(n *net.IPNet) net.IP { +func BroadcastAddr(n *net.IPNet) (net.IP, error) { // The golang net package doesn't make it easy to calculate the broadcast address. :( - broadcast := NewIP(len(n.IP)) + broadcast, err := NewIP(len(n.IP)) + if err != nil { + return net.IP{}, err + } for i := 0; i < len(n.IP); i++ { broadcast[i] = n.IP[i] | ^n.Mask[i] } - return broadcast + return broadcast, nil } // incrementIP returns the given IP + 1