Skip to content

Commit

Permalink
[client] Use GPO DNS Policy Config to configure DNS if present (#3319)
Browse files Browse the repository at this point in the history
  • Loading branch information
lixmal authored Feb 13, 2025
1 parent a930c2a commit c4a6daf
Show file tree
Hide file tree
Showing 2 changed files with 140 additions and 87 deletions.
220 changes: 136 additions & 84 deletions client/internal/dns/host_windows.go
Original file line number Diff line number Diff line change
@@ -1,48 +1,72 @@
package dns

import (
"errors"
"fmt"
"io"
"strings"
"syscall"

"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/windows/registry"

nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/internal/statemanager"
)

var (
userenv = syscall.NewLazyDLL("userenv.dll")

// https://learn.microsoft.com/en-us/windows/win32/api/userenv/nf-userenv-refreshpolicyex
refreshPolicyExFn = userenv.NewProc("RefreshPolicyEx")
)

const (
dnsPolicyConfigMatchPath = `SYSTEM\CurrentControlSet\Services\Dnscache\Parameters\DnsPolicyConfig\NetBird-Match`
dnsPolicyConfigMatchPath = `SYSTEM\CurrentControlSet\Services\Dnscache\Parameters\DnsPolicyConfig\NetBird-Match`
gpoDnsPolicyRoot = `SOFTWARE\Policies\Microsoft\Windows NT\DNSClient`
gpoDnsPolicyConfigMatchPath = gpoDnsPolicyRoot + `\DnsPolicyConfig\NetBird-Match`

dnsPolicyConfigVersionKey = "Version"
dnsPolicyConfigVersionValue = 2
dnsPolicyConfigNameKey = "Name"
dnsPolicyConfigGenericDNSServersKey = "GenericDNSServers"
dnsPolicyConfigConfigOptionsKey = "ConfigOptions"
dnsPolicyConfigConfigOptionsValue = 0x8
)

const (
interfaceConfigPath = `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces`
interfaceConfigNameServerKey = "NameServer"
interfaceConfigSearchListKey = "SearchList"

// RP_FORCE: Reapply all policies even if no policy change was detected
rpForce = 0x1
)

type registryConfigurator struct {
guid string
routingAll bool
gpo bool
}

func newHostManager(wgInterface WGIface) (*registryConfigurator, error) {
guid, err := wgInterface.GetInterfaceGUIDString()
if err != nil {
return nil, err
}
return newHostManagerWithGuid(guid)
}

func newHostManagerWithGuid(guid string) (*registryConfigurator, error) {
var useGPO bool
k, err := registry.OpenKey(registry.LOCAL_MACHINE, gpoDnsPolicyRoot, registry.QUERY_VALUE)
if err != nil {
log.Debugf("failed to open GPO DNS policy root: %v", err)
} else {
closer(k)
useGPO = true
log.Infof("detected GPO DNS policy configuration, using policy store")
}

return &registryConfigurator{
guid: guid,
gpo: useGPO,
}, nil
}

Expand All @@ -51,30 +75,23 @@ func (r *registryConfigurator) supportCustomPort() bool {
}

func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error {
var err error
if config.RouteAll {
err = r.addDNSSetupForAll(config.ServerIP)
if err != nil {
if err := r.addDNSSetupForAll(config.ServerIP); err != nil {
return fmt.Errorf("add dns setup: %w", err)
}
} else if r.routingAll {
err = r.deleteInterfaceRegistryKeyProperty(interfaceConfigNameServerKey)
if err != nil {
if err := r.deleteInterfaceRegistryKeyProperty(interfaceConfigNameServerKey); err != nil {
return fmt.Errorf("delete interface registry key property: %w", err)
}
r.routingAll = false
log.Infof("removed %s as main DNS forwarder for this peer", config.ServerIP)
}

if err := stateManager.UpdateState(&ShutdownState{Guid: r.guid}); err != nil {
if err := stateManager.UpdateState(&ShutdownState{Guid: r.guid, GPO: r.gpo}); err != nil {
log.Errorf("failed to update shutdown state: %s", err)
}

var (
searchDomains []string
matchDomains []string
)

var searchDomains, matchDomains []string
for _, dConf := range config.Domains {
if dConf.Disabled {
continue
Expand All @@ -86,91 +103,80 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
}

if len(matchDomains) != 0 {
err = r.addDNSMatchPolicy(matchDomains, config.ServerIP)
if err := r.addDNSMatchPolicy(matchDomains, config.ServerIP); err != nil {
return fmt.Errorf("add dns match policy: %w", err)
}
} else {
err = removeRegistryKeyFromDNSPolicyConfig(dnsPolicyConfigMatchPath)
}
if err != nil {
return fmt.Errorf("add dns match policy: %w", err)
if err := r.removeDNSMatchPolicies(); err != nil {
return fmt.Errorf("remove dns match policies: %w", err)
}
}

err = r.updateSearchDomains(searchDomains)
if err != nil {
if err := r.updateSearchDomains(searchDomains); err != nil {
return fmt.Errorf("update search domains: %w", err)
}

return nil
}

func (r *registryConfigurator) addDNSSetupForAll(ip string) error {
err := r.setInterfaceRegistryKeyStringValue(interfaceConfigNameServerKey, ip)
if err != nil {
return fmt.Errorf("adding dns setup for all failed with error: %w", err)
if err := r.setInterfaceRegistryKeyStringValue(interfaceConfigNameServerKey, ip); err != nil {
return fmt.Errorf("adding dns setup for all failed: %w", err)
}
r.routingAll = true
log.Infof("configured %s:53 as main DNS forwarder for this peer", ip)
return nil
}

func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip string) error {
_, err := registry.OpenKey(registry.LOCAL_MACHINE, dnsPolicyConfigMatchPath, registry.QUERY_VALUE)
if err == nil {
err = registry.DeleteKey(registry.LOCAL_MACHINE, dnsPolicyConfigMatchPath)
if err != nil {
return fmt.Errorf("unable to remove existing key from registry, key: HKEY_LOCAL_MACHINE\\%s, error: %w", dnsPolicyConfigMatchPath, err)
}
// if the gpo key is present, we need to put our DNS settings there, otherwise our config might be ignored
// see https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-gpnrpt/8cc31cb9-20cb-4140-9e85-3e08703b4745
policyPath := dnsPolicyConfigMatchPath
if r.gpo {
policyPath = gpoDnsPolicyConfigMatchPath
}

regKey, _, err := registry.CreateKey(registry.LOCAL_MACHINE, dnsPolicyConfigMatchPath, registry.SET_VALUE)
if err != nil {
return fmt.Errorf("unable to create registry key, key: HKEY_LOCAL_MACHINE\\%s, error: %w", dnsPolicyConfigMatchPath, err)
if err := removeRegistryKeyFromDNSPolicyConfig(policyPath); err != nil {
return fmt.Errorf("remove existing dns policy: %w", err)
}

err = regKey.SetDWordValue(dnsPolicyConfigVersionKey, dnsPolicyConfigVersionValue)
regKey, _, err := registry.CreateKey(registry.LOCAL_MACHINE, policyPath, registry.SET_VALUE)
if err != nil {
return fmt.Errorf("unable to set registry value for %s, error: %w", dnsPolicyConfigVersionKey, err)
return fmt.Errorf("create registry key HKEY_LOCAL_MACHINE\\%s: %w", policyPath, err)
}
defer closer(regKey)

err = regKey.SetStringsValue(dnsPolicyConfigNameKey, domains)
if err != nil {
return fmt.Errorf("unable to set registry value for %s, error: %w", dnsPolicyConfigNameKey, err)
if err := regKey.SetDWordValue(dnsPolicyConfigVersionKey, dnsPolicyConfigVersionValue); err != nil {
return fmt.Errorf("set %s: %w", dnsPolicyConfigVersionKey, err)
}

err = regKey.SetStringValue(dnsPolicyConfigGenericDNSServersKey, ip)
if err != nil {
return fmt.Errorf("unable to set registry value for %s, error: %w", dnsPolicyConfigGenericDNSServersKey, err)
if err := regKey.SetStringsValue(dnsPolicyConfigNameKey, domains); err != nil {
return fmt.Errorf("set %s: %w", dnsPolicyConfigNameKey, err)
}

err = regKey.SetDWordValue(dnsPolicyConfigConfigOptionsKey, dnsPolicyConfigConfigOptionsValue)
if err != nil {
return fmt.Errorf("unable to set registry value for %s, error: %w", dnsPolicyConfigConfigOptionsKey, err)
if err := regKey.SetStringValue(dnsPolicyConfigGenericDNSServersKey, ip); err != nil {
return fmt.Errorf("set %s: %w", dnsPolicyConfigGenericDNSServersKey, err)
}

log.Infof("added %d match domains to the state. Domain list: %s", len(domains), domains)

return nil
}

func (r *registryConfigurator) restoreHostDNS() error {
if err := removeRegistryKeyFromDNSPolicyConfig(dnsPolicyConfigMatchPath); err != nil {
log.Errorf("remove registry key from dns policy config: %s", err)
if err := regKey.SetDWordValue(dnsPolicyConfigConfigOptionsKey, dnsPolicyConfigConfigOptionsValue); err != nil {
return fmt.Errorf("set %s: %w", dnsPolicyConfigConfigOptionsKey, err)
}

if err := r.deleteInterfaceRegistryKeyProperty(interfaceConfigSearchListKey); err != nil {
return fmt.Errorf("remove interface registry key: %w", err)
if r.gpo {
if err := refreshGroupPolicy(); err != nil {
log.Warnf("failed to refresh group policy: %v", err)
}
}

log.Infof("added %d match domains. Domain list: %s", len(domains), domains)
return nil
}

func (r *registryConfigurator) updateSearchDomains(domains []string) error {
err := r.setInterfaceRegistryKeyStringValue(interfaceConfigSearchListKey, strings.Join(domains, ","))
if err != nil {
return fmt.Errorf("adding search domain failed with error: %w", err)
if err := r.setInterfaceRegistryKeyStringValue(interfaceConfigSearchListKey, strings.Join(domains, ",")); err != nil {
return fmt.Errorf("update search domains: %w", err)
}

log.Infof("updated the search domains in the registry with %d domains. Domain list: %s", len(domains), domains)

log.Infof("updated search domains: %s", domains)
return nil
}

Expand All @@ -181,11 +187,9 @@ func (r *registryConfigurator) setInterfaceRegistryKeyStringValue(key, value str
}
defer closer(regKey)

err = regKey.SetStringValue(key, value)
if err != nil {
return fmt.Errorf("applying key %s with value \"%s\" for interface failed with error: %w", key, value, err)
if err := regKey.SetStringValue(key, value); err != nil {
return fmt.Errorf("set key %s=%s: %w", key, value, err)
}

return nil
}

Expand All @@ -196,43 +200,91 @@ func (r *registryConfigurator) deleteInterfaceRegistryKeyProperty(propertyKey st
}
defer closer(regKey)

err = regKey.DeleteValue(propertyKey)
if err != nil {
return fmt.Errorf("deleting registry key %s for interface failed with error: %w", propertyKey, err)
if err := regKey.DeleteValue(propertyKey); err != nil {
return fmt.Errorf("delete registry key %s: %w", propertyKey, err)
}

return nil
}

func (r *registryConfigurator) getInterfaceRegistryKey() (registry.Key, error) {
var regKey registry.Key

regKeyPath := interfaceConfigPath + "\\" + r.guid

regKey, err := registry.OpenKey(registry.LOCAL_MACHINE, regKeyPath, registry.SET_VALUE)
if err != nil {
return regKey, fmt.Errorf("unable to open the interface registry key, key: HKEY_LOCAL_MACHINE\\%s, error: %w", regKeyPath, err)
return regKey, fmt.Errorf("open HKEY_LOCAL_MACHINE\\%s: %w", regKeyPath, err)
}

return regKey, nil
}

func (r *registryConfigurator) restoreUncleanShutdownDNS() error {
if err := r.restoreHostDNS(); err != nil {
return fmt.Errorf("restoring dns via registry: %w", err)
func (r *registryConfigurator) restoreHostDNS() error {
if err := r.removeDNSMatchPolicies(); err != nil {
log.Errorf("remove dns match policies: %s", err)
}

if err := r.deleteInterfaceRegistryKeyProperty(interfaceConfigSearchListKey); err != nil {
return fmt.Errorf("remove interface registry key: %w", err)
}

return nil
}

func (r *registryConfigurator) removeDNSMatchPolicies() error {
var merr *multierror.Error
if err := removeRegistryKeyFromDNSPolicyConfig(dnsPolicyConfigMatchPath); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove local registry key: %w", err))
}

if err := removeRegistryKeyFromDNSPolicyConfig(gpoDnsPolicyConfigMatchPath); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove GPO registry key: %w", err))
}

if err := refreshGroupPolicy(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("refresh group policy: %w", err))
}

return nberrors.FormatErrorOrNil(merr)
}

func (r *registryConfigurator) restoreUncleanShutdownDNS() error {
return r.restoreHostDNS()
}

func removeRegistryKeyFromDNSPolicyConfig(regKeyPath string) error {
k, err := registry.OpenKey(registry.LOCAL_MACHINE, regKeyPath, registry.QUERY_VALUE)
if err == nil {
defer closer(k)
err = registry.DeleteKey(registry.LOCAL_MACHINE, regKeyPath)
if err != nil {
return fmt.Errorf("unable to remove existing key from registry, key: HKEY_LOCAL_MACHINE\\%s, error: %w", regKeyPath, err)
if err != nil {
log.Debugf("failed to open HKEY_LOCAL_MACHINE\\%s: %v", regKeyPath, err)
return nil
}

closer(k)
if err := registry.DeleteKey(registry.LOCAL_MACHINE, regKeyPath); err != nil {
return fmt.Errorf("delete HKEY_LOCAL_MACHINE\\%s: %w", regKeyPath, err)
}

return nil
}

func refreshGroupPolicy() error {
// refreshPolicyExFn.Call() panics if the func is not found
defer func() {
if r := recover(); r != nil {
log.Errorf("Recovered from panic: %v", r)
}
}()

ret, _, err := refreshPolicyExFn.Call(
// bMachine = TRUE (computer policy)
uintptr(1),
// dwOptions = RP_FORCE
uintptr(rpForce),
)

if ret == 0 {
if err != nil && !errors.Is(err, syscall.Errno(0)) {
return fmt.Errorf("RefreshPolicyEx failed: %w", err)
}
return fmt.Errorf("RefreshPolicyEx failed")
}

return nil
}

Expand Down
7 changes: 4 additions & 3 deletions client/internal/dns/unclean_shutdown_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,17 @@ import (

type ShutdownState struct {
Guid string
GPO bool
}

func (s *ShutdownState) Name() string {
return "dns_state"
}

func (s *ShutdownState) Cleanup() error {
manager, err := newHostManagerWithGuid(s.Guid)
if err != nil {
return fmt.Errorf("create host manager: %w", err)
manager := &registryConfigurator{
guid: s.Guid,
gpo: s.GPO,
}

if err := manager.restoreUncleanShutdownDNS(); err != nil {
Expand Down

0 comments on commit c4a6daf

Please sign in to comment.