Skip to content

[client] Add reverse dns zone #3217

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Feb 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 111 additions & 0 deletions client/internal/dns.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
package internal

import (
"fmt"
"net"
"slices"
"strings"

"github.com/miekg/dns"
log "github.com/sirupsen/logrus"

nbdns "github.com/netbirdio/netbird/dns"
)

func createPTRRecord(aRecord nbdns.SimpleRecord, ipNet *net.IPNet) (nbdns.SimpleRecord, bool) {
ip := net.ParseIP(aRecord.RData)
if ip == nil || ip.To4() == nil {
return nbdns.SimpleRecord{}, false
}

if !ipNet.Contains(ip) {
return nbdns.SimpleRecord{}, false
}

ipOctets := strings.Split(ip.String(), ".")
slices.Reverse(ipOctets)
rdnsName := dns.Fqdn(strings.Join(ipOctets, ".") + ".in-addr.arpa")

return nbdns.SimpleRecord{
Name: rdnsName,
Type: int(dns.TypePTR),
Class: aRecord.Class,
TTL: aRecord.TTL,
RData: dns.Fqdn(aRecord.Name),
}, true
}

// generateReverseZoneName creates the reverse DNS zone name for a given network
func generateReverseZoneName(ipNet *net.IPNet) (string, error) {
networkIP := ipNet.IP.Mask(ipNet.Mask)
maskOnes, _ := ipNet.Mask.Size()

// round up to nearest byte
octetsToUse := (maskOnes + 7) / 8

octets := strings.Split(networkIP.String(), ".")
if octetsToUse > len(octets) {
return "", fmt.Errorf("invalid network mask size for reverse DNS: %d", maskOnes)
}

reverseOctets := make([]string, octetsToUse)
for i := 0; i < octetsToUse; i++ {
reverseOctets[octetsToUse-1-i] = octets[i]
}

return dns.Fqdn(strings.Join(reverseOctets, ".") + ".in-addr.arpa"), nil
}

// zoneExists checks if a zone with the given name already exists in the configuration
func zoneExists(config *nbdns.Config, zoneName string) bool {
for _, zone := range config.CustomZones {
if zone.Domain == zoneName {
log.Debugf("reverse DNS zone %s already exists", zoneName)
return true
}
}
return false
}

// collectPTRRecords gathers all PTR records for the given network from A records
func collectPTRRecords(config *nbdns.Config, ipNet *net.IPNet) []nbdns.SimpleRecord {
var records []nbdns.SimpleRecord

for _, zone := range config.CustomZones {
for _, record := range zone.Records {
if record.Type != int(dns.TypeA) {
continue
}

if ptrRecord, ok := createPTRRecord(record, ipNet); ok {
records = append(records, ptrRecord)
}
}
}

return records
}

// addReverseZone adds a reverse DNS zone to the configuration for the given network
func addReverseZone(config *nbdns.Config, ipNet *net.IPNet) {
zoneName, err := generateReverseZoneName(ipNet)
if err != nil {
log.Warn(err)
return
}

if zoneExists(config, zoneName) {
log.Debugf("reverse DNS zone %s already exists", zoneName)
return
}

records := collectPTRRecords(config, ipNet)

reverseZone := nbdns.CustomZone{
Domain: zoneName,
Records: records,
}

config.CustomZones = append(config.CustomZones, reverseZone)
log.Debugf("added reverse DNS zone: %s with %d records", zoneName, len(records))
}
8 changes: 7 additions & 1 deletion client/internal/dns/host.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@ import (
nbdns "github.com/netbirdio/netbird/dns"
)

const (
ipv4ReverseZone = ".in-addr.arpa"
ipv6ReverseZone = ".ip6.arpa"
)

type hostManager interface {
applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error
restoreHostDNS() error
Expand Down Expand Up @@ -94,9 +99,10 @@ func dnsConfigToHostDNSConfig(dnsConfig nbdns.Config, ip string, port int) HostD
}

for _, customZone := range dnsConfig.CustomZones {
matchOnly := strings.HasSuffix(customZone.Domain, ipv4ReverseZone) || strings.HasSuffix(customZone.Domain, ipv6ReverseZone)
config.Domains = append(config.Domains, DomainConfig{
Domain: strings.TrimSuffix(customZone.Domain, "."),
MatchOnly: false,
MatchOnly: matchOnly,
})
}

Expand Down
10 changes: 6 additions & 4 deletions client/internal/dns/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -395,11 +395,11 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {

localMuxUpdates, localRecords, err := s.buildLocalHandlerUpdate(update.CustomZones)
if err != nil {
return fmt.Errorf("not applying dns update, error: %v", err)
return fmt.Errorf("local handler updater: %w", err)
}
upstreamMuxUpdates, err := s.buildUpstreamHandlerUpdate(update.NameServerGroups)
if err != nil {
return fmt.Errorf("not applying dns update, error: %v", err)
return fmt.Errorf("upstream handler updater: %w", err)
}
muxUpdates := append(localMuxUpdates, upstreamMuxUpdates...) //nolint:gocritic

Expand Down Expand Up @@ -440,7 +440,8 @@ func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone)

for _, customZone := range customZones {
if len(customZone.Records) == 0 {
return nil, nil, fmt.Errorf("received an empty list of records")
log.Warnf("received a custom zone with empty records, skipping domain: %s", customZone.Domain)
continue
}

muxUpdates = append(muxUpdates, handlerWrapper{
Expand All @@ -452,7 +453,8 @@ func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone)
for _, record := range customZone.Records {
var class uint16 = dns.ClassINET
if record.Class != nbdns.DefaultClass {
return nil, nil, fmt.Errorf("received an invalid class type: %s", record.Class)
log.Warnf("received an invalid class type: %s", record.Class)
continue
}

key := buildRecordKey(record.Name, class, uint16(record.Type))
Expand Down
8 changes: 6 additions & 2 deletions client/internal/dns/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ func TestUpdateDNSServer(t *testing.T) {
shouldFail: true,
},
{
name: "Invalid Custom Zone Records list Should Fail",
name: "Invalid Custom Zone Records list Should Skip",
initLocalMap: make(registrationMap),
initUpstreamMap: make(registeredHandlerMap),
initSerial: 0,
Expand All @@ -285,7 +285,11 @@ func TestUpdateDNSServer(t *testing.T) {
},
},
},
shouldFail: true,
expectedUpstreamMap: registeredHandlerMap{generateDummyHandler(".", nameServers).id(): handlerWrapper{
domain: ".",
handler: dummyHandler,
priority: PriorityDefault,
}},
},
{
name: "Empty Config Should Succeed and Clean Maps",
Expand Down
11 changes: 8 additions & 3 deletions client/internal/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -952,7 +952,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
protoDNSConfig = &mgmProto.DNSConfig{}
}

if err := e.dnsServer.UpdateDNSServer(serial, toDNSConfig(protoDNSConfig)); err != nil {
if err := e.dnsServer.UpdateDNSServer(serial, toDNSConfig(protoDNSConfig, e.wgInterface.Address().Network)); err != nil {
log.Errorf("failed to update dns server, err: %v", err)
}

Expand Down Expand Up @@ -1021,7 +1021,7 @@ func toRouteDomains(myPubKey string, protoRoutes []*mgmProto.Route) []string {
return dnsRoutes
}

func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig) nbdns.Config {
func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network *net.IPNet) nbdns.Config {
dnsUpdate := nbdns.Config{
ServiceEnable: protoDNSConfig.GetServiceEnable(),
CustomZones: make([]nbdns.CustomZone, 0),
Expand Down Expand Up @@ -1061,6 +1061,11 @@ func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig) nbdns.Config {
}
dnsUpdate.NameServerGroups = append(dnsUpdate.NameServerGroups, dnsNSGroup)
}

if len(dnsUpdate.CustomZones) > 0 {
addReverseZone(&dnsUpdate, network)
}

return dnsUpdate
}

Expand Down Expand Up @@ -1367,7 +1372,7 @@ func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, error) {
return nil, nil, err
}
routes := toRoutes(netMap.GetRoutes())
dnsCfg := toDNSConfig(netMap.GetDNSConfig())
dnsCfg := toDNSConfig(netMap.GetDNSConfig(), e.wgInterface.Address().Network)
return routes, &dnsCfg, nil
}

Expand Down
15 changes: 15 additions & 0 deletions client/internal/engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,15 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
RemovePeerFunc: func(peerKey string) error {
return nil
},
AddressFunc: func() iface.WGAddress {
return iface.WGAddress{
IP: net.ParseIP("10.20.0.1"),
Network: &net.IPNet{
IP: net.ParseIP("10.20.0.0"),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
}
},
}
engine.wgInterface = wgIface
engine.routeManager = routemanager.NewManager(routemanager.ManagerConfig{
Expand Down Expand Up @@ -692,6 +701,9 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
},
},
},
{
Domain: "0.66.100.in-addr.arpa.",
},
},
NameServerGroups: []*mgmtProto.NameServerGroup{
{
Expand Down Expand Up @@ -721,6 +733,9 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
},
},
},
{
Domain: "0.66.100.in-addr.arpa.",
},
},
expectedNSGroupsLen: 1,
expectedNSGroups: []*nbdns.NameServerGroup{
Expand Down
Loading