From 35176df49073b0b37eaea3ece33964453392d25b Mon Sep 17 00:00:00 2001 From: LiZhenCheng9527 Date: Wed, 19 Feb 2025 10:35:57 +0800 Subject: [PATCH 01/11] refector dns resolver Signed-off-by: LiZhenCheng9527 --- pkg/controller/controller.go | 13 +- pkg/dns/ads_handler.go | 263 +++++++++++++++++++++++++++++ pkg/dns/ads_handler_test.go | 124 ++++++++++++++ pkg/dns/dns.go | 294 +++++++++----------------------- pkg/dns/dns_test.go | 318 +++++++++++++---------------------- 5 files changed, 593 insertions(+), 419 deletions(-) create mode 100644 pkg/dns/ads_handler.go create mode 100644 pkg/dns/ads_handler_test.go diff --git a/pkg/controller/controller.go b/pkg/controller/controller.go index 8ded81bac..c5be6eca5 100644 --- a/pkg/controller/controller.go +++ b/pkg/controller/controller.go @@ -31,7 +31,6 @@ import ( "kmesh.net/kmesh/pkg/controller/encryption/ipsec" manage "kmesh.net/kmesh/pkg/controller/manage" "kmesh.net/kmesh/pkg/controller/security" - "kmesh.net/kmesh/pkg/dns" "kmesh.net/kmesh/pkg/kolog" "kmesh.net/kmesh/pkg/kube" "kmesh.net/kmesh/pkg/logger" @@ -156,12 +155,12 @@ func (c *Controller) Start(stopCh <-chan struct{}) error { } if c.client.AdsController != nil { - dnsResolver, err := dns.NewDNSResolver(c.client.AdsController.Processor.Cache) - if err != nil { - return fmt.Errorf("dns resolver create failed: %v", err) - } - dnsResolver.StartDNSResolver(stopCh) - c.client.AdsController.Processor.DnsResolverChan = dnsResolver.DnsResolverChan + // dnsResolver, err := dns.NewDNSResolver(c.client.AdsController.Processor.Cache) + // if err != nil { + // return fmt.Errorf("dns resolver create failed: %v", err) + // } + // dnsResolver.StartDNSResolver(stopCh) + // c.client.AdsController.Processor.DnsResolverChan = dnsResolver.DnsResolverChan } return c.client.Run(stopCh) diff --git a/pkg/dns/ads_handler.go b/pkg/dns/ads_handler.go new file mode 100644 index 000000000..9a41d6205 --- /dev/null +++ b/pkg/dns/ads_handler.go @@ -0,0 +1,263 @@ +/* + * Copyright The Kmesh 
Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package dns + +import ( + "net" + "net/netip" + "slices" + + clusterv3 "github.com/envoyproxy/go-control-plane/envoy/config/cluster/v3" + v3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" + endpointv3 "github.com/envoyproxy/go-control-plane/envoy/config/endpoint/v3" + "google.golang.org/protobuf/types/known/wrapperspb" + + core_v2 "kmesh.net/kmesh/api/v2/core" + "kmesh.net/kmesh/pkg/controller/ads" +) + +// adsDnsResolver is DNS resolver of Kernel Native +type AdsDnsResolver struct { + Clusters chan []*clusterv3.Cluster + adsCache *ads.AdsCache + DnsResolver *DNSResolver +} + +func NewAdsDnsResolver(adsCache *ads.AdsCache) (*AdsDnsResolver, error) { + resolver, err := NewDNSResolver() + if err != nil { + return nil, err + } + return &AdsDnsResolver{ + Clusters: make(chan []*clusterv3.Cluster), + adsCache: adsCache, + DnsResolver: resolver, + }, nil +} + +func (adsResolver *AdsDnsResolver) StartAdsDnsResolver(stopCh <-chan struct{}) { + go adsResolver.startAdsResolver() + go adsResolver.refreshAdsWorker() + go func() { + <-stopCh + adsResolver.DnsResolver.dnsRefreshQueue.ShutDown() + close(adsResolver.Clusters) + }() +} + +func (adsResolver *AdsDnsResolver) startAdsResolver() { + rateLimiter := make(chan struct{}, MaxConcurrency) + for clusters := range adsResolver.Clusters { + rateLimiter <- struct{}{} + go func(clusters []*clusterv3.Cluster) { + defer func() { + <-rateLimiter 
+ }() + adsResolver.DnsResolver.resolveDomains(clusters) + }(clusters) + } +} + +func (adsResolver *AdsDnsResolver) refreshAdsDns() bool { + element, quit := adsResolver.DnsResolver.dnsRefreshQueue.Get() + if quit { + return false + } + defer adsResolver.DnsResolver.dnsRefreshQueue.Done(element) + e := element.(*pendingResolveDomain) + adsResolver.DnsResolver.RLock() + _, exist := adsResolver.DnsResolver.cache[e.domainName] + adsResolver.DnsResolver.RUnlock() + // if the domain is no longer watched, no need to refresh it + if !exist { + return true + } + // addrs, _, err := adsResolver.DnsResolver.resolve(e) + // if err != nil { + // log.Errorf("dns error is: %v", err) + // } + // entry.addresses = addrs + adsResolver.DnsResolver.resolve(e) + adsResolver.adsCache.ClusterCache.Flush() + return true +} + +func (adsResolver *AdsDnsResolver) refreshAdsWorker() { + for adsResolver.refreshAdsDns() { + } +} + +// resolveDomains takes a slice of cluster +func (r *DNSResolver) resolveDomains(clusters []*clusterv3.Cluster) { + domains := getPendingResolveDomain(clusters) + + // Stow domain updates, need to remove unwatched domains first + r.removeUnwatchedDomain(domains) + for _, v := range domains { + r.Lock() + if r.cache[v.domainName] == nil { + r.cache[v.domainName] = &domainCacheEntry{} + } + r.Unlock() + r.dnsRefreshQueue.AddAfter(v, 0) + } +} + +func (adsResolver *AdsDnsResolver) adsDnsResolve(domain *pendingResolveDomain) { + adsResolver.DnsResolver.RLock() + entry := adsResolver.DnsResolver.cache[domain.domainName] + // This can happen when the domain is deleted before the refresher tick reaches + if entry == nil { + adsResolver.DnsResolver.RUnlock() + log.Errorf("domain is not in cache") + return + } + + adsResolver.DnsResolver.RUnlock() + + addrs, ttl, err := adsResolver.DnsResolver.doResolve(domain.domainName, domain.refreshRate) + if err != nil { + log.Errorf("failed to resolve: %v, err: %v", domain.domainName, err) + } + + // for the newly resolved domain just 
push to bpf map + log.Infof("resolve dns name: %s, addr: %v", domain.domainName, addrs) + // refresh the dns address periodically by respecting the dnsRefreshRate and ttl, which one is shorter + if ttl > domain.refreshRate { + ttl = domain.refreshRate + } + if ttl == 0 { + ttl = DeRefreshInterval + } + if !slices.Equal(entry.addresses, addrs) { + for _, c := range domain.clusters { + ready := overwriteDnsCluster(c, domain.domainName, addrs) + if ready { + if !adsResolver.adsCache.UpdateApiClusterIfExists(core_v2.ApiStatus_UPDATE, c) { + log.Debugf("cluster: %s is deleted", c.Name) + return + } + } + } + } + adsResolver.DnsResolver.Lock() + entry.addresses = addrs + adsResolver.DnsResolver.Unlock() +} + +func overwriteDnsCluster(cluster *clusterv3.Cluster, domain string, addrs []string) bool { + buildLbEndpoints := func(port uint32) []*endpointv3.LbEndpoint { + lbEndpoints := make([]*endpointv3.LbEndpoint, 0, len(addrs)) + for _, addr := range addrs { + ip := net.ParseIP(addr) + if ip == nil { + continue + } + if ip.To4() == nil { + continue + } + lbEndpoint := &endpointv3.LbEndpoint{ + HealthStatus: v3.HealthStatus_HEALTHY, + HostIdentifier: &endpointv3.LbEndpoint_Endpoint{ + Endpoint: &endpointv3.Endpoint{ + Address: &v3.Address{ + Address: &v3.Address_SocketAddress{ + SocketAddress: &v3.SocketAddress{ + Address: addr, + PortSpecifier: &v3.SocketAddress_PortValue{ + PortValue: port, + }, + }, + }, + }, + }, + }, + // TODO: support LoadBalancingWeight + LoadBalancingWeight: &wrapperspb.UInt32Value{ + Value: 1, + }, + } + lbEndpoints = append(lbEndpoints, lbEndpoint) + } + return lbEndpoints + } + + ready := true + for _, e := range cluster.LoadAssignment.Endpoints { + pos := -1 + var lbEndpoints []*endpointv3.LbEndpoint + for i, le := range e.LbEndpoints { + socketAddr, ok := le.GetEndpoint().GetAddress().GetAddress().(*v3.Address_SocketAddress) + if !ok { + continue + } + _, err := netip.ParseAddr(socketAddr.SocketAddress.Address) + if err != nil { + if 
socketAddr.SocketAddress.Address == domain { + pos = i + lbEndpoints = buildLbEndpoints(socketAddr.SocketAddress.GetPortValue()) + } else { + // There is other domains not resolved for this cluster + ready = false + } + } + } + if pos >= 0 { + e.LbEndpoints = slices.Replace(e.LbEndpoints, pos, pos+1, lbEndpoints...) + } + } + + return ready +} + +// Get domain name and refreshrate from cluster, and also store cluster and port in the return addresses for later use +func getPendingResolveDomain(clusters []*clusterv3.Cluster) map[string]*pendingResolveDomain { + domains := make(map[string]*pendingResolveDomain) + + for _, cluster := range clusters { + if cluster.LoadAssignment == nil { + continue + } + + for _, e := range cluster.LoadAssignment.Endpoints { + for _, le := range e.LbEndpoints { + socketAddr, ok := le.GetEndpoint().GetAddress().GetAddress().(*v3.Address_SocketAddress) + if !ok { + continue + } + address := socketAddr.SocketAddress.Address + if _, err := netip.ParseAddr(address); err == nil { + // This is an ip address + continue + } + + if v, ok := domains[address]; ok { + v.clusters = append(v.clusters, cluster) + } else { + domainWithRefreshRate := &pendingResolveDomain{ + domainName: address, + clusters: []*clusterv3.Cluster{cluster}, + refreshRate: cluster.GetDnsRefreshRate().AsDuration(), + } + domains[address] = domainWithRefreshRate + } + } + } + } + + return domains +} diff --git a/pkg/dns/ads_handler_test.go b/pkg/dns/ads_handler_test.go new file mode 100644 index 000000000..c9faa1558 --- /dev/null +++ b/pkg/dns/ads_handler_test.go @@ -0,0 +1,124 @@ +/* + * Copyright The Kmesh Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package dns + +import ( + "math/rand" + "slices" + "sync" + "testing" + + clusterv3 "github.com/envoyproxy/go-control-plane/envoy/config/cluster/v3" + v3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" + endpointv3 "github.com/envoyproxy/go-control-plane/envoy/config/endpoint/v3" + "google.golang.org/protobuf/types/known/wrapperspb" + core_v2 "kmesh.net/kmesh/api/v2/core" + "kmesh.net/kmesh/pkg/controller/ads" +) + +type fakeAdsDnsServer struct { +} + +func TestOverwriteDNSCluster(t *testing.T) { + domain := "www.google.com" + addrs := []string{"10.1.1.1", "10.1.1.2"} + cluster := &clusterv3.Cluster{ + Name: "ut-cluster", + ClusterDiscoveryType: &clusterv3.Cluster_Type{ + Type: clusterv3.Cluster_LOGICAL_DNS, + }, + LoadAssignment: &endpointv3.ClusterLoadAssignment{ + ClusterName: "ut-cluster", + Endpoints: []*endpointv3.LocalityLbEndpoints{ + { + LoadBalancingWeight: wrapperspb.UInt32(30), + Priority: uint32(15), + LbEndpoints: []*endpointv3.LbEndpoint{ + { + HealthStatus: v3.HealthStatus_HEALTHY, + HostIdentifier: &endpointv3.LbEndpoint_Endpoint{ + Endpoint: &endpointv3.Endpoint{ + Address: &v3.Address{ + Address: &v3.Address_SocketAddress{ + SocketAddress: &v3.SocketAddress{ + Address: domain, + PortSpecifier: &v3.SocketAddress_PortValue{ + PortValue: uint32(9898), + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + } + + overwriteDnsCluster(cluster, domain, addrs) + + endpoints := cluster.GetLoadAssignment().GetEndpoints()[0].GetLbEndpoints() + if len(endpoints) != 2 { + t.Errorf("Expected 2 LbEndpoints, but got %d", 
len(endpoints)) + } + out := []string{} + for _, e := range endpoints { + socketAddr, ok := e.GetEndpoint().GetAddress().GetAddress().(*v3.Address_SocketAddress) + if !ok { + continue + } + address := socketAddr.SocketAddress.Address + out = append(out, address) + } + if !slices.Equal(out, addrs) { + t.Errorf("OverwriteDNSCluster error, expected %v, but got %v", out, addrs) + } +} + +// This test aims to evaluate the concurrent writing behavior of the adsCache by utilizing the test race feature. +// The test verifies the ability of the adsCache to handle concurrent access and updates correctly in a multi-goroutine environment. +func TestADSCacheConcurrentWriting(t *testing.T) { + adsCache := ads.NewAdsCache(nil) + cluster := &clusterv3.Cluster{ + Name: "ut-cluster", + ClusterDiscoveryType: &clusterv3.Cluster_Type{ + Type: clusterv3.Cluster_LOGICAL_DNS, + }, + } + adsCache.CreateApiClusterByCds(core_v2.ApiStatus_NONE, cluster) + + var wg sync.WaitGroup + for i := 0; i < 100; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for j := 0; j < 100; j++ { + currentStatus := adsCache.GetApiClusterStatus(cluster.GetName()) + newStatus := currentStatus + core_v2.ApiStatus(rand.Intn(3)-1) + if rand.Intn(2) == 0 { + adsCache.UpdateApiClusterIfExists(newStatus, cluster) + } else { + adsCache.UpdateApiClusterStatus(cluster.GetName(), newStatus) + } + } + }() + } + + wg.Wait() +} diff --git a/pkg/dns/dns.go b/pkg/dns/dns.go index dfbe4fd6d..4561504cc 100644 --- a/pkg/dns/dns.go +++ b/pkg/dns/dns.go @@ -19,22 +19,15 @@ package dns import ( "fmt" "net" - "net/netip" - "slices" "sort" "sync" "time" clusterv3 "github.com/envoyproxy/go-control-plane/envoy/config/cluster/v3" - v3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" - endpointv3 "github.com/envoyproxy/go-control-plane/envoy/config/endpoint/v3" - "google.golang.org/protobuf/types/known/wrapperspb" "github.com/miekg/dns" "k8s.io/client-go/util/workqueue" - core_v2 "kmesh.net/kmesh/api/v2/core" - 
"kmesh.net/kmesh/pkg/controller/ads" "kmesh.net/kmesh/pkg/logger" ) @@ -49,12 +42,12 @@ const ( ) type DNSResolver struct { - DnsResolverChan chan []*clusterv3.Cluster + // DnsResolverChan chan []*clusterv3.Cluster client *dns.Client resolvConfServers []string cache map[string]*domainCacheEntry // adsCache is used for update bpf map - adsCache *ads.AdsCache + // adsCache *ads.AdsCache // dns refresh priority queue based on exp dnsRefreshQueue workqueue.TypedDelayingInterface[any] sync.RWMutex @@ -73,76 +66,11 @@ type pendingResolveDomain struct { refreshRate time.Duration } -func overwriteDnsCluster(cluster *clusterv3.Cluster, domain string, addrs []string) bool { - buildLbEndpoints := func(port uint32) []*endpointv3.LbEndpoint { - lbEndpoints := make([]*endpointv3.LbEndpoint, 0, len(addrs)) - for _, addr := range addrs { - ip := net.ParseIP(addr) - if ip == nil { - continue - } - if ip.To4() == nil { - continue - } - lbEndpoint := &endpointv3.LbEndpoint{ - HealthStatus: v3.HealthStatus_HEALTHY, - HostIdentifier: &endpointv3.LbEndpoint_Endpoint{ - Endpoint: &endpointv3.Endpoint{ - Address: &v3.Address{ - Address: &v3.Address_SocketAddress{ - SocketAddress: &v3.SocketAddress{ - Address: addr, - PortSpecifier: &v3.SocketAddress_PortValue{ - PortValue: port, - }, - }, - }, - }, - }, - }, - // TODO: support LoadBalancingWeight - LoadBalancingWeight: &wrapperspb.UInt32Value{ - Value: 1, - }, - } - lbEndpoints = append(lbEndpoints, lbEndpoint) - } - return lbEndpoints - } - - ready := true - for _, e := range cluster.LoadAssignment.Endpoints { - pos := -1 - var lbEndpoints []*endpointv3.LbEndpoint - for i, le := range e.LbEndpoints { - socketAddr, ok := le.GetEndpoint().GetAddress().GetAddress().(*v3.Address_SocketAddress) - if !ok { - continue - } - _, err := netip.ParseAddr(socketAddr.SocketAddress.Address) - if err != nil { - if socketAddr.SocketAddress.Address == domain { - pos = i - lbEndpoints = buildLbEndpoints(socketAddr.SocketAddress.GetPortValue()) - } else { - 
// There is other domains not resolved for this cluster - ready = false - } - } - } - if pos >= 0 { - e.LbEndpoints = slices.Replace(e.LbEndpoints, pos, pos+1, lbEndpoints...) - } - } - - return ready -} - -func NewDNSResolver(adsCache *ads.AdsCache) (*DNSResolver, error) { +func NewDNSResolver() (*DNSResolver, error) { r := &DNSResolver{ - DnsResolverChan: make(chan []*clusterv3.Cluster), - cache: map[string]*domainCacheEntry{}, - adsCache: adsCache, + // DnsResolverChan: make(chan []*clusterv3.Cluster), + cache: map[string]*domainCacheEntry{}, + // adsCache: adsCache, dnsRefreshQueue: workqueue.NewTypedDelayingQueueWithConfig(workqueue.TypedDelayingQueueConfig[any]{Name: "dnsRefreshQueue"}), client: &dns.Client{ DialTimeout: 5 * time.Second, @@ -164,43 +92,27 @@ func NewDNSResolver(adsCache *ads.AdsCache) (*DNSResolver, error) { return r, nil } -func (r *DNSResolver) StartDNSResolver(stopCh <-chan struct{}) { - go r.startResolver() - go r.refreshWorker() - go func() { - <-stopCh - r.dnsRefreshQueue.ShutDown() - close(r.DnsResolverChan) - }() -} +// func (r *DNSResolver) StartDNSResolver(stopCh <-chan struct{}) { +// // go r.startResolver() +// // go r.refreshWorker() +// go func() { +// <-stopCh +// r.dnsRefreshQueue.ShutDown() +// // close(r.DnsResolverChan) +// }() +// } // startResolver watches the DnsResolver Channel -func (r *DNSResolver) startResolver() { - rateLimiter := make(chan struct{}, MaxConcurrency) - for clusters := range r.DnsResolverChan { - rateLimiter <- struct{}{} - go func(clusters []*clusterv3.Cluster) { - defer func() { <-rateLimiter }() - r.resolveDomains(clusters) - }(clusters) - } -} - -// resolveDomains takes a slice of cluster -func (r *DNSResolver) resolveDomains(clusters []*clusterv3.Cluster) { - domains := getPendingResolveDomain(clusters) - - // Stow domain updates, need to remove unwatched domains first - r.removeUnwatchedDomain(domains) - for _, v := range domains { - r.Lock() - if r.cache[v.domainName] == nil { - 
r.cache[v.domainName] = &domainCacheEntry{} - } - r.Unlock() - r.dnsRefreshQueue.AddAfter(v, 0) - } -} +// func (r *DNSResolver) startResolver() { +// rateLimiter := make(chan struct{}, MaxConcurrency) +// for clusters := range r.DnsResolverChan { +// rateLimiter <- struct{}{} +// go func(clusters []*clusterv3.Cluster) { +// defer func() { <-rateLimiter }() +// r.resolveDomains(clusters) +// }(clusters) +// } +// } // removeUnwatchedDomain cancels any scheduled re-resolve for names we no longer care about func (r *DNSResolver) removeUnwatchedDomain(domains map[string]*pendingResolveDomain) { @@ -223,76 +135,76 @@ func (r *DNSResolver) resolve(v *pendingResolveDomain) { r.RUnlock() return } - r.RUnlock() addrs, ttl, err := r.doResolve(v.domainName, v.refreshRate) - if err == nil { - // for the newly resolved domain just push to bpf map - log.Infof("resolve dns name: %s, addr: %v", v.domainName, addrs) - // refresh the dns address periodically by respecting the dnsRefreshRate and ttl, which one is shorter - if ttl > v.refreshRate { - ttl = v.refreshRate - } - if ttl == 0 { - ttl = DeRefreshInterval - } - if !slices.Equal(entry.addresses, addrs) { - for _, c := range v.clusters { - ready := overwriteDnsCluster(c, v.domainName, addrs) - if ready { - if !r.adsCache.UpdateApiClusterIfExists(core_v2.ApiStatus_UPDATE, c) { - log.Debugf("cluster: %s is deleted", c.Name) - return - } - } - } - } - r.Lock() - entry.addresses = addrs - r.Unlock() - } else { - ttl = RetryAfter - log.Errorf("resolve domain %s failed: %v, retry after %v", v.domainName, err, ttl) + fmt.Printf("domainName is: %v, address is: %v\n", v.domainName, addrs) + if err != nil { + return } + r.RLock() + entry.addresses = addrs + r.RUnlock() + // push to refresh queue r.dnsRefreshQueue.AddAfter(v, ttl) + return + + // if err == nil { + // // for the newly resolved domain just push to bpf map + // log.Infof("resolve dns name: %s, addr: %v", v.domainName, addrs) + // // refresh the dns address periodically by 
respecting the dnsRefreshRate and ttl, which one is shorter + // if ttl > v.refreshRate { + // ttl = v.refreshRate + // } + // if ttl == 0 { + // ttl = DeRefreshInterval + // } + // if !slices.Equal(entry.addresses, addrs) { + // for _, c := range v.clusters { + // ready := overwriteDnsCluster(c, v.domainName, addrs) + // if ready { + // if !r.adsCache.UpdateApiClusterIfExists(core_v2.ApiStatus_UPDATE, c) { + // log.Debugf("cluster: %s is deleted", c.Name) + // return + // } + // } + // } + // } + // r.Lock() + // entry.addresses = addrs + // r.Unlock() + // } else { + // ttl = RetryAfter + // log.Errorf("resolve domain %s failed: %v, retry after %v", v.domainName, err, ttl) + // } } -func (r *DNSResolver) refreshWorker() { - for r.refreshDNS() { - } -} +// func (r *DNSResolver) refreshWorker() { +// for r.refreshDNS() { +// } +// } // refreshDNS use a delay working queue to handle dns refresh -func (r *DNSResolver) refreshDNS() bool { - element, quit := r.dnsRefreshQueue.Get() - if quit { - return false - } - defer r.dnsRefreshQueue.Done(element) - dr := element.(*pendingResolveDomain) - r.RLock() - _, exist := r.cache[dr.domainName] - r.RUnlock() - // if the domain is no longer watched, no need to refresh it - if !exist { - return true - } - r.resolve(dr) - r.adsCache.ClusterCache.Flush() - return true -} - -func (r *DNSResolver) GetDNSAddresses(domain string) []string { - r.RLock() - defer r.RUnlock() - if entry, ok := r.cache[domain]; ok { - return entry.addresses - } - return nil -} +// func (r *DNSResolver) refreshDNS() bool { +// element, quit := r.dnsRefreshQueue.Get() +// if quit { +// return false +// } +// defer r.dnsRefreshQueue.Done(element) +// dr := element.(*pendingResolveDomain) +// r.RLock() +// _, exist := r.cache[dr.domainName] +// r.RUnlock() +// // if the domain is no longer watched, no need to refresh it +// if !exist { +// return true +// } +// r.resolve(dr) +// r.adsCache.ClusterCache.Flush() +// return true +// } func (r *DNSResolver) 
GetAllCachedDomains() []string { r.RLock() @@ -403,41 +315,3 @@ func getMinTTL(m *dns.Msg, refreshRate time.Duration) time.Duration { } return minTTL } - -// Get domain name and refreshrate from cluster, and also store cluster and port in the return addresses for later use -func getPendingResolveDomain(clusters []*clusterv3.Cluster) map[string]*pendingResolveDomain { - domains := make(map[string]*pendingResolveDomain) - - for _, cluster := range clusters { - if cluster.LoadAssignment == nil { - continue - } - - for _, e := range cluster.LoadAssignment.Endpoints { - for _, le := range e.LbEndpoints { - socketAddr, ok := le.GetEndpoint().GetAddress().GetAddress().(*v3.Address_SocketAddress) - if !ok { - continue - } - address := socketAddr.SocketAddress.Address - if _, err := netip.ParseAddr(address); err == nil { - // This is an ip address - continue - } - - if v, ok := domains[address]; ok { - v.clusters = append(v.clusters, cluster) - } else { - domainWithRefreshRate := &pendingResolveDomain{ - domainName: address, - clusters: []*clusterv3.Cluster{cluster}, - refreshRate: cluster.GetDnsRefreshRate().AsDuration(), - } - domains[address] = domainWithRefreshRate - } - } - } - } - - return domains -} diff --git a/pkg/dns/dns_test.go b/pkg/dns/dns_test.go index 409e10d05..ed3ad1aef 100644 --- a/pkg/dns/dns_test.go +++ b/pkg/dns/dns_test.go @@ -19,7 +19,6 @@ package dns import ( "fmt" "math" - "math/rand" "net" "reflect" "sync" @@ -30,12 +29,6 @@ import ( v3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" endpointv3 "github.com/envoyproxy/go-control-plane/envoy/config/endpoint/v3" "github.com/miekg/dns" - "github.com/stretchr/testify/assert" - "google.golang.org/protobuf/types/known/wrapperspb" - "istio.io/istio/pkg/slices" - "istio.io/istio/pkg/test/util/retry" - - core_v2 "kmesh.net/kmesh/api/v2/core" "kmesh.net/kmesh/pkg/controller/ads" ) @@ -49,17 +42,28 @@ type fakeDNSServer struct { hosts map[string]int } +func (r *DNSResolver) GetDNSAddresses(domain 
string) []string { + r.Lock() + defer r.Unlock() + if entry, ok := r.cache[domain]; ok { + return entry.addresses + } + return nil +} + func TestDNS(t *testing.T) { fakeDNSServer := newFakeDNSServer() - testDNSResolver, err := NewDNSResolver(ads.NewAdsCache(nil)) + // testDNSResolver, err := NewDNSResolver(ads.NewAdsCache(nil)) + testDNSResolver, err := NewAdsDnsResolver(ads.NewAdsCache(nil)) if err != nil { t.Fatal(err) } stopCh := make(chan struct{}) defer close(stopCh) - testDNSResolver.StartDNSResolver(stopCh) - testDNSResolver.resolvConfServers = []string{fakeDNSServer.Server.PacketConn.LocalAddr().String()} + // testDNSResolver.StartDNSResolver(stopCh) + testDNSResolver.StartAdsDnsResolver(stopCh) + testDNSResolver.DnsResolver.resolvConfServers = []string{fakeDNSServer.Server.PacketConn.LocalAddr().String()} testCases := []struct { name string @@ -139,15 +143,15 @@ func TestDNS(t *testing.T) { domainName: testcase.domain, refreshRate: testcase.refreshRate, } - testDNSResolver.Lock() - testDNSResolver.cache[testcase.domain] = &domainCacheEntry{} - testDNSResolver.Unlock() + testDNSResolver.DnsResolver.Lock() + testDNSResolver.DnsResolver.cache[testcase.domain] = &domainCacheEntry{} + testDNSResolver.DnsResolver.Unlock() - testDNSResolver.resolve(input) + testDNSResolver.DnsResolver.resolve(input) time.Sleep(2 * time.Second) - res := testDNSResolver.GetDNSAddresses(testcase.domain) + res := testDNSResolver.DnsResolver.GetDNSAddresses(testcase.domain) if len(res) != 0 || len(testcase.expected) != 0 { if !reflect.DeepEqual(res, testcase.expected) { t.Errorf("dns resolve for %s do not match. 
\n got %v\nwant %v", testcase.domain, res, testcase.expected) @@ -156,7 +160,7 @@ func TestDNS(t *testing.T) { if testcase.expectedAfterTTL != nil { ttl := time.Duration(math.Min(float64(testcase.ttl), float64(testcase.refreshRate))) time.Sleep(ttl + 1) - res = testDNSResolver.GetDNSAddresses(testcase.domain) + res = testDNSResolver.DnsResolver.GetDNSAddresses(testcase.domain) if !reflect.DeepEqual(res, testcase.expectedAfterTTL) { t.Errorf("dns refresh after ttl failed, for %s do not match. \n got %v\nwant %v", testcase.domain, res, testcase.expectedAfterTTL) } @@ -167,96 +171,6 @@ func TestDNS(t *testing.T) { wg.Wait() } -// This test aims to evaluate the concurrent writing behavior of the adsCache by utilizing the test race feature. -// The test verifies the ability of the adsCache to handle concurrent access and updates correctly in a multi-goroutine environment. -func TestADSCacheConcurrentWriting(t *testing.T) { - adsCache := ads.NewAdsCache(nil) - cluster := &clusterv3.Cluster{ - Name: "ut-cluster", - ClusterDiscoveryType: &clusterv3.Cluster_Type{ - Type: clusterv3.Cluster_LOGICAL_DNS, - }, - } - adsCache.CreateApiClusterByCds(core_v2.ApiStatus_NONE, cluster) - - var wg sync.WaitGroup - for i := 0; i < 100; i++ { - wg.Add(1) - go func() { - defer wg.Done() - for j := 0; j < 100; j++ { - currentStatus := adsCache.GetApiClusterStatus(cluster.GetName()) - newStatus := currentStatus + core_v2.ApiStatus(rand.Intn(3)-1) - if rand.Intn(2) == 0 { - adsCache.UpdateApiClusterIfExists(newStatus, cluster) - } else { - adsCache.UpdateApiClusterStatus(cluster.GetName(), newStatus) - } - } - }() - } - - wg.Wait() -} - -func TestOverwriteDNSCluster(t *testing.T) { - domain := "www.google.com" - addrs := []string{"10.1.1.1", "10.1.1.2"} - cluster := &clusterv3.Cluster{ - Name: "ut-cluster", - ClusterDiscoveryType: &clusterv3.Cluster_Type{ - Type: clusterv3.Cluster_LOGICAL_DNS, - }, - LoadAssignment: &endpointv3.ClusterLoadAssignment{ - ClusterName: "ut-cluster", - Endpoints: 
[]*endpointv3.LocalityLbEndpoints{ - { - LoadBalancingWeight: wrapperspb.UInt32(30), - Priority: uint32(15), - LbEndpoints: []*endpointv3.LbEndpoint{ - { - HealthStatus: v3.HealthStatus_HEALTHY, - HostIdentifier: &endpointv3.LbEndpoint_Endpoint{ - Endpoint: &endpointv3.Endpoint{ - Address: &v3.Address{ - Address: &v3.Address_SocketAddress{ - SocketAddress: &v3.SocketAddress{ - Address: domain, - PortSpecifier: &v3.SocketAddress_PortValue{ - PortValue: uint32(9898), - }, - }, - }, - }, - }, - }, - }, - }, - }, - }, - }, - } - - overwriteDnsCluster(cluster, domain, addrs) - - endpoints := cluster.GetLoadAssignment().GetEndpoints()[0].GetLbEndpoints() - if len(endpoints) != 2 { - t.Errorf("Expected 2 LbEndpoints, but got %d", len(endpoints)) - } - out := []string{} - for _, e := range endpoints { - socketAddr, ok := e.GetEndpoint().GetAddress().GetAddress().(*v3.Address_SocketAddress) - if !ok { - continue - } - address := socketAddr.SocketAddress.Address - out = append(out, address) - } - if !slices.Equal(out, addrs) { - t.Errorf("OverwriteDNSCluster error, expected %v, but got %v", out, addrs) - } -} - func newFakeDNSServer() *fakeDNSServer { var wg sync.WaitGroup wg.Add(1) @@ -425,99 +339,99 @@ func TestGetPendingResolveDomain(t *testing.T) { } } -func TestHandleCdsResponseWithDns(t *testing.T) { - cluster1 := &clusterv3.Cluster{ - Name: "ut-cluster1", - ClusterDiscoveryType: &clusterv3.Cluster_Type{ - Type: clusterv3.Cluster_LOGICAL_DNS, - }, - LoadAssignment: &endpointv3.ClusterLoadAssignment{ - Endpoints: []*endpointv3.LocalityLbEndpoints{ - { - LbEndpoints: []*endpointv3.LbEndpoint{ - { - HostIdentifier: &endpointv3.LbEndpoint_Endpoint{ - Endpoint: &endpointv3.Endpoint{ - Address: &v3.Address{ - Address: &v3.Address_SocketAddress{ - SocketAddress: &v3.SocketAddress{ - Address: "foo.bar", - PortSpecifier: &v3.SocketAddress_PortValue{ - PortValue: uint32(9898), - }, - }, - }, - }, - }, - }, - }, - }, - }, - }, - }, - } - cluster2 := &clusterv3.Cluster{ - Name: 
"ut-cluster2", - ClusterDiscoveryType: &clusterv3.Cluster_Type{ - Type: clusterv3.Cluster_STRICT_DNS, - }, - LoadAssignment: &endpointv3.ClusterLoadAssignment{ - Endpoints: []*endpointv3.LocalityLbEndpoints{ - { - LbEndpoints: []*endpointv3.LbEndpoint{ - { - HostIdentifier: &endpointv3.LbEndpoint_Endpoint{ - Endpoint: &endpointv3.Endpoint{ - Address: &v3.Address{ - Address: &v3.Address_SocketAddress{ - SocketAddress: &v3.SocketAddress{ - Address: "foo.baz", - PortSpecifier: &v3.SocketAddress_PortValue{ - PortValue: uint32(9898), - }, - }, - }, - }, - }, - }, - }, - }, - }, - }, - }, - } - - testcases := []struct { - name string - clusters []*clusterv3.Cluster - expected []string - }{ - { - name: "add clusters with DNS type", - clusters: []*clusterv3.Cluster{cluster1, cluster2}, - expected: []string{"foo.bar", "foo.baz"}, - }, - { - name: "remove all DNS type clusters", - clusters: []*clusterv3.Cluster{}, - expected: []string{}, - }, - } - - p := ads.NewController(nil).Processor - stopCh := make(chan struct{}) - defer close(stopCh) - dnsResolver, err := NewDNSResolver(ads.NewAdsCache(nil)) - assert.NoError(t, err) - dnsResolver.StartDNSResolver(stopCh) - p.DnsResolverChan = dnsResolver.DnsResolverChan - for _, tc := range testcases { - t.Run(tc.name, func(t *testing.T) { - // notify dns resolver - dnsResolver.DnsResolverChan <- tc.clusters - retry.UntilOrFail(t, func() bool { - return slices.EqualUnordered(tc.expected, dnsResolver.GetAllCachedDomains()) - }, retry.Timeout(1*time.Second)) - }) - } -} +// func TestHandleCdsResponseWithDns(t *testing.T) { +// cluster1 := &clusterv3.Cluster{ +// Name: "ut-cluster1", +// ClusterDiscoveryType: &clusterv3.Cluster_Type{ +// Type: clusterv3.Cluster_LOGICAL_DNS, +// }, +// LoadAssignment: &endpointv3.ClusterLoadAssignment{ +// Endpoints: []*endpointv3.LocalityLbEndpoints{ +// { +// LbEndpoints: []*endpointv3.LbEndpoint{ +// { +// HostIdentifier: &endpointv3.LbEndpoint_Endpoint{ +// Endpoint: &endpointv3.Endpoint{ +// Address: 
&v3.Address{ +// Address: &v3.Address_SocketAddress{ +// SocketAddress: &v3.SocketAddress{ +// Address: "foo.bar", +// PortSpecifier: &v3.SocketAddress_PortValue{ +// PortValue: uint32(9898), +// }, +// }, +// }, +// }, +// }, +// }, +// }, +// }, +// }, +// }, +// }, +// } +// cluster2 := &clusterv3.Cluster{ +// Name: "ut-cluster2", +// ClusterDiscoveryType: &clusterv3.Cluster_Type{ +// Type: clusterv3.Cluster_STRICT_DNS, +// }, +// LoadAssignment: &endpointv3.ClusterLoadAssignment{ +// Endpoints: []*endpointv3.LocalityLbEndpoints{ +// { +// LbEndpoints: []*endpointv3.LbEndpoint{ +// { +// HostIdentifier: &endpointv3.LbEndpoint_Endpoint{ +// Endpoint: &endpointv3.Endpoint{ +// Address: &v3.Address{ +// Address: &v3.Address_SocketAddress{ +// SocketAddress: &v3.SocketAddress{ +// Address: "foo.baz", +// PortSpecifier: &v3.SocketAddress_PortValue{ +// PortValue: uint32(9898), +// }, +// }, +// }, +// }, +// }, +// }, +// }, +// }, +// }, +// }, +// }, +// } + +// testcases := []struct { +// name string +// clusters []*clusterv3.Cluster +// expected []string +// }{ +// { +// name: "add clusters with DNS type", +// clusters: []*clusterv3.Cluster{cluster1, cluster2}, +// expected: []string{"foo.bar", "foo.baz"}, +// }, +// { +// name: "remove all DNS type clusters", +// clusters: []*clusterv3.Cluster{}, +// expected: []string{}, +// }, +// } + +// p := ads.NewController(nil).Processor +// stopCh := make(chan struct{}) +// defer close(stopCh) +// dnsResolver, err := NewDNSResolver() +// assert.NoError(t, err) +// dnsResolver.StartDNSResolver(stopCh) +// p.DnsResolverChan = dnsResolver.DnsResolverChan +// for _, tc := range testcases { +// t.Run(tc.name, func(t *testing.T) { +// // notify dns resolver +// dnsResolver.DnsResolverChan <- tc.clusters +// retry.UntilOrFail(t, func() bool { +// return slices.EqualUnordered(tc.expected, dnsResolver.GetAllCachedDomains()) +// }, retry.Timeout(1*time.Second)) +// }) +// } +// } From 4b084733bf70e9bc4228a896c51deb1cd8cf3ca3 Mon 
Sep 17 00:00:00 2001 From: LiZhenCheng9527 Date: Fri, 21 Feb 2025 11:27:26 +0800 Subject: [PATCH 02/11] modify UT of dns Signed-off-by: LiZhenCheng9527 --- pkg/controller/controller.go | 13 ++-- pkg/dns/ads_handler.go | 65 ++++---------------- pkg/dns/ads_handler_test.go | 103 ++++++++++++++++++++++++++++++- pkg/dns/dns.go | 113 ++++++----------------------------- pkg/dns/dns_test.go | 109 ++++----------------------------- 5 files changed, 151 insertions(+), 252 deletions(-) diff --git a/pkg/controller/controller.go b/pkg/controller/controller.go index c5be6eca5..2d1805a7f 100644 --- a/pkg/controller/controller.go +++ b/pkg/controller/controller.go @@ -31,6 +31,7 @@ import ( "kmesh.net/kmesh/pkg/controller/encryption/ipsec" manage "kmesh.net/kmesh/pkg/controller/manage" "kmesh.net/kmesh/pkg/controller/security" + "kmesh.net/kmesh/pkg/dns" "kmesh.net/kmesh/pkg/kolog" "kmesh.net/kmesh/pkg/kube" "kmesh.net/kmesh/pkg/logger" @@ -155,12 +156,12 @@ func (c *Controller) Start(stopCh <-chan struct{}) error { } if c.client.AdsController != nil { - // dnsResolver, err := dns.NewDNSResolver(c.client.AdsController.Processor.Cache) - // if err != nil { - // return fmt.Errorf("dns resolver create failed: %v", err) - // } - // dnsResolver.StartDNSResolver(stopCh) - // c.client.AdsController.Processor.DnsResolverChan = dnsResolver.DnsResolverChan + dnsResolver, err := dns.NewAdsDnsResolver(c.client.AdsController.Processor.Cache) + if err != nil { + return fmt.Errorf("dns resolver of Kernel-Native mode create failed: %v", err) + } + dnsResolver.StartAdsDnsResolver(stopCh) + c.client.AdsController.Processor.DnsResolverChan = dnsResolver.Clusters } return c.client.Run(stopCh) diff --git a/pkg/dns/ads_handler.go b/pkg/dns/ads_handler.go index 9a41d6205..a7af371c7 100644 --- a/pkg/dns/ads_handler.go +++ b/pkg/dns/ads_handler.go @@ -20,6 +20,7 @@ import ( "net" "net/netip" "slices" + "time" clusterv3 "github.com/envoyproxy/go-control-plane/envoy/config/cluster/v3" v3 
"github.com/envoyproxy/go-control-plane/envoy/config/core/v3" @@ -86,12 +87,9 @@ func (adsResolver *AdsDnsResolver) refreshAdsDns() bool { if !exist { return true } - // addrs, _, err := adsResolver.DnsResolver.resolve(e) - // if err != nil { - // log.Errorf("dns error is: %v", err) - // } - // entry.addresses = addrs - adsResolver.DnsResolver.resolve(e) + // adsResolver.DnsResolver.resolve(e) + addres, ttl := adsResolver.DnsResolver.resolve(e) + adsResolver.adsDnsResolve(e, addres, ttl) adsResolver.adsCache.ClusterCache.Flush() return true } @@ -101,62 +99,23 @@ func (adsResolver *AdsDnsResolver) refreshAdsWorker() { } } -// resolveDomains takes a slice of cluster -func (r *DNSResolver) resolveDomains(clusters []*clusterv3.Cluster) { - domains := getPendingResolveDomain(clusters) - - // Stow domain updates, need to remove unwatched domains first - r.removeUnwatchedDomain(domains) - for _, v := range domains { - r.Lock() - if r.cache[v.domainName] == nil { - r.cache[v.domainName] = &domainCacheEntry{} - } - r.Unlock() - r.dnsRefreshQueue.AddAfter(v, 0) - } -} - -func (adsResolver *AdsDnsResolver) adsDnsResolve(domain *pendingResolveDomain) { - adsResolver.DnsResolver.RLock() - entry := adsResolver.DnsResolver.cache[domain.domainName] - // This can happen when the domain is deleted before the refresher tick reaches - if entry == nil { - adsResolver.DnsResolver.RUnlock() - log.Errorf("domain is not in cache") - return - } - - adsResolver.DnsResolver.RUnlock() - - addrs, ttl, err := adsResolver.DnsResolver.doResolve(domain.domainName, domain.refreshRate) - if err != nil { - log.Errorf("failed to resolve: %v, err: %v", domain.domainName, err) - } - - // for the newly resolved domain just push to bpf map - log.Infof("resolve dns name: %s, addr: %v", domain.domainName, addrs) - // refresh the dns address periodically by respecting the dnsRefreshRate and ttl, which one is shorter +func (adsResolver *AdsDnsResolver) adsDnsResolve(domain *pendingResolveDomain, addrs 
[]string, ttl time.Duration) { if ttl > domain.refreshRate { ttl = domain.refreshRate } if ttl == 0 { ttl = DeRefreshInterval } - if !slices.Equal(entry.addresses, addrs) { - for _, c := range domain.clusters { - ready := overwriteDnsCluster(c, domain.domainName, addrs) - if ready { - if !adsResolver.adsCache.UpdateApiClusterIfExists(core_v2.ApiStatus_UPDATE, c) { - log.Debugf("cluster: %s is deleted", c.Name) - return - } + for _, c := range domain.clusters { + ready := overwriteDnsCluster(c, domain.domainName, addrs) + if ready { + if !adsResolver.adsCache.UpdateApiClusterIfExists(core_v2.ApiStatus_UPDATE, c) { + log.Debugf("cluster: %s is deleted", c.Name) + return } } } - adsResolver.DnsResolver.Lock() - entry.addresses = addrs - adsResolver.DnsResolver.Unlock() + return } func overwriteDnsCluster(cluster *clusterv3.Cluster, domain string, addrs []string) bool { diff --git a/pkg/dns/ads_handler_test.go b/pkg/dns/ads_handler_test.go index c9faa1558..9fc35283e 100644 --- a/pkg/dns/ads_handler_test.go +++ b/pkg/dns/ads_handler_test.go @@ -18,14 +18,18 @@ package dns import ( "math/rand" - "slices" "sync" "testing" + "time" clusterv3 "github.com/envoyproxy/go-control-plane/envoy/config/cluster/v3" v3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" endpointv3 "github.com/envoyproxy/go-control-plane/envoy/config/endpoint/v3" + "github.com/stretchr/testify/assert" "google.golang.org/protobuf/types/known/wrapperspb" + "istio.io/istio/pkg/slices" + "istio.io/istio/pkg/test/util/retry" + core_v2 "kmesh.net/kmesh/api/v2/core" "kmesh.net/kmesh/pkg/controller/ads" ) @@ -122,3 +126,100 @@ func TestADSCacheConcurrentWriting(t *testing.T) { wg.Wait() } + +func TestHandleCdsResponseWithDns(t *testing.T) { + cluster1 := &clusterv3.Cluster{ + Name: "ut-cluster1", + ClusterDiscoveryType: &clusterv3.Cluster_Type{ + Type: clusterv3.Cluster_LOGICAL_DNS, + }, + LoadAssignment: &endpointv3.ClusterLoadAssignment{ + Endpoints: []*endpointv3.LocalityLbEndpoints{ + { + 
LbEndpoints: []*endpointv3.LbEndpoint{ + { + HostIdentifier: &endpointv3.LbEndpoint_Endpoint{ + Endpoint: &endpointv3.Endpoint{ + Address: &v3.Address{ + Address: &v3.Address_SocketAddress{ + SocketAddress: &v3.SocketAddress{ + Address: "foo.bar", + PortSpecifier: &v3.SocketAddress_PortValue{ + PortValue: uint32(9898), + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + } + cluster2 := &clusterv3.Cluster{ + Name: "ut-cluster2", + ClusterDiscoveryType: &clusterv3.Cluster_Type{ + Type: clusterv3.Cluster_STRICT_DNS, + }, + LoadAssignment: &endpointv3.ClusterLoadAssignment{ + Endpoints: []*endpointv3.LocalityLbEndpoints{ + { + LbEndpoints: []*endpointv3.LbEndpoint{ + { + HostIdentifier: &endpointv3.LbEndpoint_Endpoint{ + Endpoint: &endpointv3.Endpoint{ + Address: &v3.Address{ + Address: &v3.Address_SocketAddress{ + SocketAddress: &v3.SocketAddress{ + Address: "foo.baz", + PortSpecifier: &v3.SocketAddress_PortValue{ + PortValue: uint32(9898), + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + } + + testcases := []struct { + name string + clusters []*clusterv3.Cluster + expected []string + }{ + { + name: "add clusters with DNS type", + clusters: []*clusterv3.Cluster{cluster1, cluster2}, + expected: []string{"foo.bar", "foo.baz"}, + }, + { + name: "remove all DNS type clusters", + clusters: []*clusterv3.Cluster{}, + expected: []string{}, + }, + } + + p := ads.NewController(nil).Processor + stopCh := make(chan struct{}) + defer close(stopCh) + dnsResolver, err := NewAdsDnsResolver(p.Cache) + assert.NoError(t, err) + dnsResolver.StartAdsDnsResolver(stopCh) + p.DnsResolverChan = dnsResolver.Clusters + for _, tc := range testcases { + t.Run(tc.name, func(t *testing.T) { + // notify dns resolver + dnsResolver.Clusters <- tc.clusters + retry.UntilOrFail(t, func() bool { + return slices.EqualUnordered(tc.expected, dnsResolver.DnsResolver.GetAllCachedDomains()) + }, retry.Timeout(1*time.Second)) + }) + } +} diff --git a/pkg/dns/dns.go b/pkg/dns/dns.go index 
4561504cc..cbbf0a189 100644 --- a/pkg/dns/dns.go +++ b/pkg/dns/dns.go @@ -42,12 +42,9 @@ const ( ) type DNSResolver struct { - // DnsResolverChan chan []*clusterv3.Cluster client *dns.Client resolvConfServers []string cache map[string]*domainCacheEntry - // adsCache is used for update bpf map - // adsCache *ads.AdsCache // dns refresh priority queue based on exp dnsRefreshQueue workqueue.TypedDelayingInterface[any] sync.RWMutex @@ -68,9 +65,7 @@ type pendingResolveDomain struct { func NewDNSResolver() (*DNSResolver, error) { r := &DNSResolver{ - // DnsResolverChan: make(chan []*clusterv3.Cluster), - cache: map[string]*domainCacheEntry{}, - // adsCache: adsCache, + cache: map[string]*domainCacheEntry{}, dnsRefreshQueue: workqueue.NewTypedDelayingQueueWithConfig(workqueue.TypedDelayingQueueConfig[any]{Name: "dnsRefreshQueue"}), client: &dns.Client{ DialTimeout: 5 * time.Second, @@ -92,28 +87,6 @@ func NewDNSResolver() (*DNSResolver, error) { return r, nil } -// func (r *DNSResolver) StartDNSResolver(stopCh <-chan struct{}) { -// // go r.startResolver() -// // go r.refreshWorker() -// go func() { -// <-stopCh -// r.dnsRefreshQueue.ShutDown() -// // close(r.DnsResolverChan) -// }() -// } - -// startResolver watches the DnsResolver Channel -// func (r *DNSResolver) startResolver() { -// rateLimiter := make(chan struct{}, MaxConcurrency) -// for clusters := range r.DnsResolverChan { -// rateLimiter <- struct{}{} -// go func(clusters []*clusterv3.Cluster) { -// defer func() { <-rateLimiter }() -// r.resolveDomains(clusters) -// }(clusters) -// } -// } - // removeUnwatchedDomain cancels any scheduled re-resolve for names we no longer care about func (r *DNSResolver) removeUnwatchedDomain(domains map[string]*pendingResolveDomain) { r.Lock() @@ -127,20 +100,20 @@ func (r *DNSResolver) removeUnwatchedDomain(domains map[string]*pendingResolveDo } // This functions were copied and adapted from github.com/istio/istio/pilot/pkg/model/network.go. 
-func (r *DNSResolver) resolve(v *pendingResolveDomain) { +func (r *DNSResolver) resolve(v *pendingResolveDomain) ([]string, time.Duration) { r.RLock() entry := r.cache[v.domainName] // This can happen when the domain is deleted before the refresher tick reaches if entry == nil { r.RUnlock() - return + return []string{}, time.Duration(0) } r.RUnlock() addrs, ttl, err := r.doResolve(v.domainName, v.refreshRate) - fmt.Printf("domainName is: %v, address is: %v\n", v.domainName, addrs) if err != nil { - return + log.Errorf("dns resolve failed: %v", err) + return []string{}, time.Duration(0) } r.RLock() @@ -149,71 +122,23 @@ func (r *DNSResolver) resolve(v *pendingResolveDomain) { // push to refresh queue r.dnsRefreshQueue.AddAfter(v, ttl) - return - - // if err == nil { - // // for the newly resolved domain just push to bpf map - // log.Infof("resolve dns name: %s, addr: %v", v.domainName, addrs) - // // refresh the dns address periodically by respecting the dnsRefreshRate and ttl, which one is shorter - // if ttl > v.refreshRate { - // ttl = v.refreshRate - // } - // if ttl == 0 { - // ttl = DeRefreshInterval - // } - // if !slices.Equal(entry.addresses, addrs) { - // for _, c := range v.clusters { - // ready := overwriteDnsCluster(c, v.domainName, addrs) - // if ready { - // if !r.adsCache.UpdateApiClusterIfExists(core_v2.ApiStatus_UPDATE, c) { - // log.Debugf("cluster: %s is deleted", c.Name) - // return - // } - // } - // } - // } - // r.Lock() - // entry.addresses = addrs - // r.Unlock() - // } else { - // ttl = RetryAfter - // log.Errorf("resolve domain %s failed: %v, retry after %v", v.domainName, err, ttl) - // } + return addrs, ttl } -// func (r *DNSResolver) refreshWorker() { -// for r.refreshDNS() { -// } -// } - -// refreshDNS use a delay working queue to handle dns refresh -// func (r *DNSResolver) refreshDNS() bool { -// element, quit := r.dnsRefreshQueue.Get() -// if quit { -// return false -// } -// defer r.dnsRefreshQueue.Done(element) -// dr := 
element.(*pendingResolveDomain) -// r.RLock() -// _, exist := r.cache[dr.domainName] -// r.RUnlock() -// // if the domain is no longer watched, no need to refresh it -// if !exist { -// return true -// } -// r.resolve(dr) -// r.adsCache.ClusterCache.Flush() -// return true -// } - -func (r *DNSResolver) GetAllCachedDomains() []string { - r.RLock() - defer r.RUnlock() - out := make([]string, 0, len(r.cache)) - for domain := range r.cache { - out = append(out, domain) +// resolveDomains takes a slice of cluster +func (r *DNSResolver) resolveDomains(clusters []*clusterv3.Cluster) { + domains := getPendingResolveDomain(clusters) + + // Stow domain updates, need to remove unwatched domains first + r.removeUnwatchedDomain(domains) + for _, v := range domains { + r.Lock() + if r.cache[v.domainName] == nil { + r.cache[v.domainName] = &domainCacheEntry{} + } + r.Unlock() + r.dnsRefreshQueue.AddAfter(v, 0) } - return out } // doResolve is copied and adapted from github.com/istio/istio/pilot/pkg/model/network.go. 
diff --git a/pkg/dns/dns_test.go b/pkg/dns/dns_test.go index ed3ad1aef..6f9c5d6d5 100644 --- a/pkg/dns/dns_test.go +++ b/pkg/dns/dns_test.go @@ -29,6 +29,7 @@ import ( v3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" endpointv3 "github.com/envoyproxy/go-control-plane/envoy/config/endpoint/v3" "github.com/miekg/dns" + "kmesh.net/kmesh/pkg/controller/ads" ) @@ -61,7 +62,6 @@ func TestDNS(t *testing.T) { } stopCh := make(chan struct{}) defer close(stopCh) - // testDNSResolver.StartDNSResolver(stopCh) testDNSResolver.StartAdsDnsResolver(stopCh) testDNSResolver.DnsResolver.resolvConfServers = []string{fakeDNSServer.Server.PacketConn.LocalAddr().String()} @@ -229,6 +229,16 @@ func (s *fakeDNSServer) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { } } +func (r *DNSResolver) GetAllCachedDomains() []string { + r.RLock() + defer r.RUnlock() + out := make([]string, 0, len(r.cache)) + for domain := range r.cache { + out = append(out, domain) + } + return out +} + func (s *fakeDNSServer) setHosts(domain string, surfix int) { s.mu.Lock() defer s.mu.Unlock() @@ -338,100 +348,3 @@ func TestGetPendingResolveDomain(t *testing.T) { }) } } - -// func TestHandleCdsResponseWithDns(t *testing.T) { -// cluster1 := &clusterv3.Cluster{ -// Name: "ut-cluster1", -// ClusterDiscoveryType: &clusterv3.Cluster_Type{ -// Type: clusterv3.Cluster_LOGICAL_DNS, -// }, -// LoadAssignment: &endpointv3.ClusterLoadAssignment{ -// Endpoints: []*endpointv3.LocalityLbEndpoints{ -// { -// LbEndpoints: []*endpointv3.LbEndpoint{ -// { -// HostIdentifier: &endpointv3.LbEndpoint_Endpoint{ -// Endpoint: &endpointv3.Endpoint{ -// Address: &v3.Address{ -// Address: &v3.Address_SocketAddress{ -// SocketAddress: &v3.SocketAddress{ -// Address: "foo.bar", -// PortSpecifier: &v3.SocketAddress_PortValue{ -// PortValue: uint32(9898), -// }, -// }, -// }, -// }, -// }, -// }, -// }, -// }, -// }, -// }, -// }, -// } -// cluster2 := &clusterv3.Cluster{ -// Name: "ut-cluster2", -// ClusterDiscoveryType: 
&clusterv3.Cluster_Type{ -// Type: clusterv3.Cluster_STRICT_DNS, -// }, -// LoadAssignment: &endpointv3.ClusterLoadAssignment{ -// Endpoints: []*endpointv3.LocalityLbEndpoints{ -// { -// LbEndpoints: []*endpointv3.LbEndpoint{ -// { -// HostIdentifier: &endpointv3.LbEndpoint_Endpoint{ -// Endpoint: &endpointv3.Endpoint{ -// Address: &v3.Address{ -// Address: &v3.Address_SocketAddress{ -// SocketAddress: &v3.SocketAddress{ -// Address: "foo.baz", -// PortSpecifier: &v3.SocketAddress_PortValue{ -// PortValue: uint32(9898), -// }, -// }, -// }, -// }, -// }, -// }, -// }, -// }, -// }, -// }, -// }, -// } - -// testcases := []struct { -// name string -// clusters []*clusterv3.Cluster -// expected []string -// }{ -// { -// name: "add clusters with DNS type", -// clusters: []*clusterv3.Cluster{cluster1, cluster2}, -// expected: []string{"foo.bar", "foo.baz"}, -// }, -// { -// name: "remove all DNS type clusters", -// clusters: []*clusterv3.Cluster{}, -// expected: []string{}, -// }, -// } - -// p := ads.NewController(nil).Processor -// stopCh := make(chan struct{}) -// defer close(stopCh) -// dnsResolver, err := NewDNSResolver() -// assert.NoError(t, err) -// dnsResolver.StartDNSResolver(stopCh) -// p.DnsResolverChan = dnsResolver.DnsResolverChan -// for _, tc := range testcases { -// t.Run(tc.name, func(t *testing.T) { -// // notify dns resolver -// dnsResolver.DnsResolverChan <- tc.clusters -// retry.UntilOrFail(t, func() bool { -// return slices.EqualUnordered(tc.expected, dnsResolver.GetAllCachedDomains()) -// }, retry.Timeout(1*time.Second)) -// }) -// } -// } From c41a2e132f50f4ba8ac3645ca4bcf0e9c4df5732 Mon Sep 17 00:00:00 2001 From: LiZhenCheng9527 Date: Fri, 21 Feb 2025 14:38:09 +0800 Subject: [PATCH 03/11] go lint Signed-off-by: LiZhenCheng9527 --- pkg/dns/ads_handler.go | 18 +++++++----------- pkg/dns/dns.go | 16 +++++++++++----- pkg/dns/dns_test.go | 41 +++++++++++++++++++++-------------------- 3 files changed, 39 insertions(+), 36 deletions(-) diff --git 
a/pkg/dns/ads_handler.go b/pkg/dns/ads_handler.go index a7af371c7..c288780d7 100644 --- a/pkg/dns/ads_handler.go +++ b/pkg/dns/ads_handler.go @@ -20,7 +20,6 @@ import ( "net" "net/netip" "slices" - "time" clusterv3 "github.com/envoyproxy/go-control-plane/envoy/config/cluster/v3" v3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" @@ -88,8 +87,12 @@ func (adsResolver *AdsDnsResolver) refreshAdsDns() bool { return true } // adsResolver.DnsResolver.resolve(e) - addres, ttl := adsResolver.DnsResolver.resolve(e) - adsResolver.adsDnsResolve(e, addres, ttl) + addresses, err := adsResolver.DnsResolver.resolve(e) + if err != nil { + log.Errorf("failed to dns resolve: %v", err) + return false + } + adsResolver.adsDnsResolve(e, addresses) adsResolver.adsCache.ClusterCache.Flush() return true } @@ -99,13 +102,7 @@ func (adsResolver *AdsDnsResolver) refreshAdsWorker() { } } -func (adsResolver *AdsDnsResolver) adsDnsResolve(domain *pendingResolveDomain, addrs []string, ttl time.Duration) { - if ttl > domain.refreshRate { - ttl = domain.refreshRate - } - if ttl == 0 { - ttl = DeRefreshInterval - } +func (adsResolver *AdsDnsResolver) adsDnsResolve(domain *pendingResolveDomain, addrs []string) { for _, c := range domain.clusters { ready := overwriteDnsCluster(c, domain.domainName, addrs) if ready { @@ -115,7 +112,6 @@ func (adsResolver *AdsDnsResolver) adsDnsResolve(domain *pendingResolveDomain, a } } } - return } func overwriteDnsCluster(cluster *clusterv3.Cluster, domain string, addrs []string) bool { diff --git a/pkg/dns/dns.go b/pkg/dns/dns.go index cbbf0a189..b94a2445c 100644 --- a/pkg/dns/dns.go +++ b/pkg/dns/dns.go @@ -100,20 +100,26 @@ func (r *DNSResolver) removeUnwatchedDomain(domains map[string]*pendingResolveDo } // This functions were copied and adapted from github.com/istio/istio/pilot/pkg/model/network.go. 
-func (r *DNSResolver) resolve(v *pendingResolveDomain) ([]string, time.Duration) { +func (r *DNSResolver) resolve(v *pendingResolveDomain) ([]string, error) { r.RLock() entry := r.cache[v.domainName] // This can happen when the domain is deleted before the refresher tick reaches if entry == nil { r.RUnlock() - return []string{}, time.Duration(0) + return []string{}, fmt.Errorf("cache entry for domain %s not found", v.domainName) } r.RUnlock() addrs, ttl, err := r.doResolve(v.domainName, v.refreshRate) if err != nil { - log.Errorf("dns resolve failed: %v", err) - return []string{}, time.Duration(0) + return []string{}, fmt.Errorf("dns resolve failed: %v", err) + } + + if ttl > v.refreshRate { + ttl = v.refreshRate + } + if ttl == 0 { + ttl = DeRefreshInterval } r.RLock() @@ -122,7 +128,7 @@ func (r *DNSResolver) resolve(v *pendingResolveDomain) ([]string, time.Duration) // push to refresh queue r.dnsRefreshQueue.AddAfter(v, ttl) - return addrs, ttl + return addrs, nil } // resolveDomains takes a slice of cluster diff --git a/pkg/dns/dns_test.go b/pkg/dns/dns_test.go index 6f9c5d6d5..0cef24d64 100644 --- a/pkg/dns/dns_test.go +++ b/pkg/dns/dns_test.go @@ -43,15 +43,6 @@ type fakeDNSServer struct { hosts map[string]int } -func (r *DNSResolver) GetDNSAddresses(domain string) []string { - r.Lock() - defer r.Unlock() - if entry, ok := r.cache[domain]; ok { - return entry.addresses - } - return nil -} - func TestDNS(t *testing.T) { fakeDNSServer := newFakeDNSServer() @@ -63,7 +54,8 @@ func TestDNS(t *testing.T) { stopCh := make(chan struct{}) defer close(stopCh) testDNSResolver.StartAdsDnsResolver(stopCh) - testDNSResolver.DnsResolver.resolvConfServers = []string{fakeDNSServer.Server.PacketConn.LocalAddr().String()} + dnsServer := fakeDNSServer.Server.PacketConn.LocalAddr().String() + testDNSResolver.DnsResolver.resolvConfServers = []string{dnsServer} testCases := []struct { name string @@ -229,16 +221,6 @@ func (s *fakeDNSServer) ServeDNS(w dns.ResponseWriter, r 
*dns.Msg) { } } -func (r *DNSResolver) GetAllCachedDomains() []string { - r.RLock() - defer r.RUnlock() - out := make([]string, 0, len(r.cache)) - for domain := range r.cache { - out = append(out, domain) - } - return out -} - func (s *fakeDNSServer) setHosts(domain string, surfix int) { s.mu.Lock() defer s.mu.Unlock() @@ -251,6 +233,25 @@ func (s *fakeDNSServer) setTTL(ttl uint32) { s.ttl = ttl } +func (r *DNSResolver) GetAllCachedDomains() []string { + r.RLock() + defer r.RUnlock() + out := make([]string, 0, len(r.cache)) + for domain := range r.cache { + out = append(out, domain) + } + return out +} + +func (r *DNSResolver) GetDNSAddresses(domain string) []string { + r.Lock() + defer r.Unlock() + if entry, ok := r.cache[domain]; ok { + return entry.addresses + } + return nil +} + func TestGetPendingResolveDomain(t *testing.T) { utCluster := clusterv3.Cluster{ Name: "testCluster", From 72c9a32aae596bb23f6d57ea2bf14e72461b2a2e Mon Sep 17 00:00:00 2001 From: LiZhenCheng9527 Date: Fri, 21 Feb 2025 15:17:48 +0800 Subject: [PATCH 04/11] delete unused struct Signed-off-by: LiZhenCheng9527 --- .../ads_handler.go => controller/ads/dns.go} | 91 ++-- pkg/controller/ads/dns_test.go | 447 ++++++++++++++++++ pkg/controller/controller.go | 4 +- pkg/dns/ads_handler_test.go | 225 --------- pkg/dns/dns.go | 89 ++-- pkg/dns/dns_test.go | 351 -------------- pkg/dns/utils.go | 114 +++++ 7 files changed, 658 insertions(+), 663 deletions(-) rename pkg/{dns/ads_handler.go => controller/ads/dns.go} (65%) create mode 100644 pkg/controller/ads/dns_test.go delete mode 100644 pkg/dns/ads_handler_test.go delete mode 100644 pkg/dns/dns_test.go create mode 100644 pkg/dns/utils.go diff --git a/pkg/dns/ads_handler.go b/pkg/controller/ads/dns.go similarity index 65% rename from pkg/dns/ads_handler.go rename to pkg/controller/ads/dns.go index c288780d7..46b6cac0f 100644 --- a/pkg/dns/ads_handler.go +++ b/pkg/controller/ads/dns.go @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -package dns +package ads import ( "net" @@ -25,27 +25,30 @@ import ( v3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" endpointv3 "github.com/envoyproxy/go-control-plane/envoy/config/endpoint/v3" "google.golang.org/protobuf/types/known/wrapperspb" + "k8s.io/client-go/util/workqueue" core_v2 "kmesh.net/kmesh/api/v2/core" - "kmesh.net/kmesh/pkg/controller/ads" + "kmesh.net/kmesh/pkg/dns" ) // adsDnsResolver is DNS resolver of Kernel Native type AdsDnsResolver struct { - Clusters chan []*clusterv3.Cluster - adsCache *ads.AdsCache - DnsResolver *DNSResolver + Clusters chan []*clusterv3.Cluster + adsCache *AdsCache + dnsResolver *dns.DNSResolver + dnsRefreshQueue workqueue.TypedDelayingInterface[any] } -func NewAdsDnsResolver(adsCache *ads.AdsCache) (*AdsDnsResolver, error) { - resolver, err := NewDNSResolver() +func NewAdsDnsResolver(adsCache *AdsCache) (*AdsDnsResolver, error) { + resolver, err := dns.NewDNSResolver() if err != nil { return nil, err } return &AdsDnsResolver{ - Clusters: make(chan []*clusterv3.Cluster), - adsCache: adsCache, - DnsResolver: resolver, + Clusters: make(chan []*clusterv3.Cluster), + dnsRefreshQueue: workqueue.NewTypedDelayingQueueWithConfig(workqueue.TypedDelayingQueueConfig[any]{Name: "dnsRefreshQueue"}), + adsCache: adsCache, + dnsResolver: resolver, }, nil } @@ -54,60 +57,79 @@ func (adsResolver *AdsDnsResolver) StartAdsDnsResolver(stopCh <-chan struct{}) { go adsResolver.refreshAdsWorker() go func() { <-stopCh - adsResolver.DnsResolver.dnsRefreshQueue.ShutDown() + adsResolver.dnsRefreshQueue.ShutDown() close(adsResolver.Clusters) }() } func (adsResolver *AdsDnsResolver) startAdsResolver() { - rateLimiter := make(chan struct{}, MaxConcurrency) + rateLimiter := make(chan struct{}, dns.MaxConcurrency) for clusters := range adsResolver.Clusters { rateLimiter <- struct{}{} go func(clusters []*clusterv3.Cluster) { defer func() { <-rateLimiter }() - adsResolver.DnsResolver.resolveDomains(clusters) + 
adsResolver.resolveDomains(clusters) }(clusters) } } func (adsResolver *AdsDnsResolver) refreshAdsDns() bool { - element, quit := adsResolver.DnsResolver.dnsRefreshQueue.Get() + element, quit := adsResolver.dnsRefreshQueue.Get() if quit { return false } - defer adsResolver.DnsResolver.dnsRefreshQueue.Done(element) - e := element.(*pendingResolveDomain) - adsResolver.DnsResolver.RLock() - _, exist := adsResolver.DnsResolver.cache[e.domainName] - adsResolver.DnsResolver.RUnlock() + defer adsResolver.dnsRefreshQueue.Done(element) + e := element.(*dns.PendingResolveDomain) + + adsResolver.dnsResolver.RLock() + _, exist := adsResolver.dnsResolver.Cache[e.DomainName] + adsResolver.dnsResolver.RUnlock() // if the domain is no longer watched, no need to refresh it if !exist { return true } - // adsResolver.DnsResolver.resolve(e) - addresses, err := adsResolver.DnsResolver.resolve(e) + addresses, ttl, err := adsResolver.dnsResolver.Resolve(e.DomainName) if err != nil { log.Errorf("failed to dns resolve: %v", err) return false } + if ttl > e.RefreshRate { + ttl = e.RefreshRate + } + if ttl == 0 { + ttl = dns.DeRefreshInterval + } + adsResolver.dnsRefreshQueue.AddAfter(e, ttl) + adsResolver.adsDnsResolve(e, addresses) adsResolver.adsCache.ClusterCache.Flush() return true } +func (adsResolver *AdsDnsResolver) resolveDomains(cds []*clusterv3.Cluster) { + domains := getPendingResolveDomain(cds) + + // Stow domain updates, need to remove unwatched domains first + adsResolver.dnsResolver.RemoveUnwatchedDomain(domains) + for k, v := range domains { + adsResolver.dnsResolver.ResolveDomains(k) + adsResolver.dnsRefreshQueue.AddAfter(v, 0) + } +} + func (adsResolver *AdsDnsResolver) refreshAdsWorker() { for adsResolver.refreshAdsDns() { } } -func (adsResolver *AdsDnsResolver) adsDnsResolve(domain *pendingResolveDomain, addrs []string) { - for _, c := range domain.clusters { - ready := overwriteDnsCluster(c, domain.domainName, addrs) +func (adsResolver *AdsDnsResolver) 
adsDnsResolve(pendingDomain *dns.PendingResolveDomain, addrs []string) { + for _, cluster := range pendingDomain.Clusters { + ready := overwriteDnsCluster(cluster, pendingDomain.DomainName, addrs) if ready { - if !adsResolver.adsCache.UpdateApiClusterIfExists(core_v2.ApiStatus_UPDATE, c) { - log.Debugf("cluster: %s is deleted", c.Name) + if !adsResolver.adsCache.UpdateApiClusterIfExists(core_v2.ApiStatus_UPDATE, cluster) { + log.Debugf("cluster: %s is deleted", cluster.Name) return } } @@ -179,11 +201,10 @@ func overwriteDnsCluster(cluster *clusterv3.Cluster, domain string, addrs []stri return ready } -// Get domain name and refreshrate from cluster, and also store cluster and port in the return addresses for later use -func getPendingResolveDomain(clusters []*clusterv3.Cluster) map[string]*pendingResolveDomain { - domains := make(map[string]*pendingResolveDomain) +func getPendingResolveDomain(cds []*clusterv3.Cluster) map[string]*dns.PendingResolveDomain { + domains := make(map[string]*dns.PendingResolveDomain) - for _, cluster := range clusters { + for _, cluster := range cds { if cluster.LoadAssignment == nil { continue } @@ -201,12 +222,12 @@ func getPendingResolveDomain(clusters []*clusterv3.Cluster) map[string]*pendingR } if v, ok := domains[address]; ok { - v.clusters = append(v.clusters, cluster) + v.Clusters = append(v.Clusters, cluster) } else { - domainWithRefreshRate := &pendingResolveDomain{ - domainName: address, - clusters: []*clusterv3.Cluster{cluster}, - refreshRate: cluster.GetDnsRefreshRate().AsDuration(), + domainWithRefreshRate := &dns.PendingResolveDomain{ + DomainName: address, + Clusters: []*clusterv3.Cluster{cluster}, + RefreshRate: cluster.GetDnsRefreshRate().AsDuration(), } domains[address] = domainWithRefreshRate } diff --git a/pkg/controller/ads/dns_test.go b/pkg/controller/ads/dns_test.go new file mode 100644 index 000000000..e135e0a05 --- /dev/null +++ b/pkg/controller/ads/dns_test.go @@ -0,0 +1,447 @@ +/* + * Copyright The Kmesh 
Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ads + +import ( + "math/rand" + "reflect" + "sync" + "testing" + "time" + + clusterv3 "github.com/envoyproxy/go-control-plane/envoy/config/cluster/v3" + v3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" + endpointv3 "github.com/envoyproxy/go-control-plane/envoy/config/endpoint/v3" + "github.com/stretchr/testify/assert" + "google.golang.org/protobuf/types/known/wrapperspb" + "istio.io/istio/pkg/slices" + "istio.io/istio/pkg/test/util/retry" + + core_v2 "kmesh.net/kmesh/api/v2/core" + "kmesh.net/kmesh/pkg/dns" +) + +func TestOverwriteDNSCluster(t *testing.T) { + domain := "www.google.com" + addrs := []string{"10.1.1.1", "10.1.1.2"} + cluster := &clusterv3.Cluster{ + Name: "ut-cluster", + ClusterDiscoveryType: &clusterv3.Cluster_Type{ + Type: clusterv3.Cluster_LOGICAL_DNS, + }, + LoadAssignment: &endpointv3.ClusterLoadAssignment{ + ClusterName: "ut-cluster", + Endpoints: []*endpointv3.LocalityLbEndpoints{ + { + LoadBalancingWeight: wrapperspb.UInt32(30), + Priority: uint32(15), + LbEndpoints: []*endpointv3.LbEndpoint{ + { + HealthStatus: v3.HealthStatus_HEALTHY, + HostIdentifier: &endpointv3.LbEndpoint_Endpoint{ + Endpoint: &endpointv3.Endpoint{ + Address: &v3.Address{ + Address: &v3.Address_SocketAddress{ + SocketAddress: &v3.SocketAddress{ + Address: domain, + PortSpecifier: &v3.SocketAddress_PortValue{ + PortValue: uint32(9898), + }, + }, + }, + }, + }, + }, + }, + }, + 
}, + }, + }, + } + + overwriteDnsCluster(cluster, domain, addrs) + + endpoints := cluster.GetLoadAssignment().GetEndpoints()[0].GetLbEndpoints() + if len(endpoints) != 2 { + t.Errorf("Expected 2 LbEndpoints, but got %d", len(endpoints)) + } + out := []string{} + for _, e := range endpoints { + socketAddr, ok := e.GetEndpoint().GetAddress().GetAddress().(*v3.Address_SocketAddress) + if !ok { + continue + } + address := socketAddr.SocketAddress.Address + out = append(out, address) + } + if !slices.Equal(out, addrs) { + t.Errorf("OverwriteDNSCluster error, expected %v, but got %v", out, addrs) + } +} + +// This test aims to evaluate the concurrent writing behavior of the adsCache by utilizing the test race feature. +// The test verifies the ability of the adsCache to handle concurrent access and updates correctly in a multi-goroutine environment. +func TestADSCacheConcurrentWriting(t *testing.T) { + adsCache := NewAdsCache(nil) + cluster := &clusterv3.Cluster{ + Name: "ut-cluster", + ClusterDiscoveryType: &clusterv3.Cluster_Type{ + Type: clusterv3.Cluster_LOGICAL_DNS, + }, + } + adsCache.CreateApiClusterByCds(core_v2.ApiStatus_NONE, cluster) + + var wg sync.WaitGroup + for i := 0; i < 100; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for j := 0; j < 100; j++ { + currentStatus := adsCache.GetApiClusterStatus(cluster.GetName()) + newStatus := currentStatus + core_v2.ApiStatus(rand.Intn(3)-1) + if rand.Intn(2) == 0 { + adsCache.UpdateApiClusterIfExists(newStatus, cluster) + } else { + adsCache.UpdateApiClusterStatus(cluster.GetName(), newStatus) + } + } + }() + } + + wg.Wait() +} + +func TestHandleCdsResponseWithDns(t *testing.T) { + cluster1 := &clusterv3.Cluster{ + Name: "ut-cluster1", + ClusterDiscoveryType: &clusterv3.Cluster_Type{ + Type: clusterv3.Cluster_LOGICAL_DNS, + }, + LoadAssignment: &endpointv3.ClusterLoadAssignment{ + Endpoints: []*endpointv3.LocalityLbEndpoints{ + { + LbEndpoints: []*endpointv3.LbEndpoint{ + { + HostIdentifier: 
&endpointv3.LbEndpoint_Endpoint{ + Endpoint: &endpointv3.Endpoint{ + Address: &v3.Address{ + Address: &v3.Address_SocketAddress{ + SocketAddress: &v3.SocketAddress{ + Address: "foo.bar", + PortSpecifier: &v3.SocketAddress_PortValue{ + PortValue: uint32(9898), + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + } + cluster2 := &clusterv3.Cluster{ + Name: "ut-cluster2", + ClusterDiscoveryType: &clusterv3.Cluster_Type{ + Type: clusterv3.Cluster_STRICT_DNS, + }, + LoadAssignment: &endpointv3.ClusterLoadAssignment{ + Endpoints: []*endpointv3.LocalityLbEndpoints{ + { + LbEndpoints: []*endpointv3.LbEndpoint{ + { + HostIdentifier: &endpointv3.LbEndpoint_Endpoint{ + Endpoint: &endpointv3.Endpoint{ + Address: &v3.Address{ + Address: &v3.Address_SocketAddress{ + SocketAddress: &v3.SocketAddress{ + Address: "foo.baz", + PortSpecifier: &v3.SocketAddress_PortValue{ + PortValue: uint32(9898), + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + } + + testcases := []struct { + name string + clusters []*clusterv3.Cluster + expected []string + }{ + { + name: "add clusters with DNS type", + clusters: []*clusterv3.Cluster{cluster1, cluster2}, + expected: []string{"foo.bar", "foo.baz"}, + }, + { + name: "remove all DNS type clusters", + clusters: []*clusterv3.Cluster{}, + expected: []string{}, + }, + } + + p := NewController(nil).Processor + stopCh := make(chan struct{}) + defer close(stopCh) + dnsResolver, err := NewAdsDnsResolver(p.Cache) + assert.NoError(t, err) + dnsResolver.StartAdsDnsResolver(stopCh) + p.DnsResolverChan = dnsResolver.Clusters + for _, tc := range testcases { + t.Run(tc.name, func(t *testing.T) { + // notify dns resolver + dnsResolver.Clusters <- tc.clusters + retry.UntilOrFail(t, func() bool { + return slices.EqualUnordered(tc.expected, dnsResolver.dnsResolver.GetAllCachedDomains()) + }, retry.Timeout(1*time.Second)) + }) + } +} + +func TestDNS(t *testing.T) { + fakeDNSServer := dns.NewFakeDNSServer() + + testDNSResolver, err := 
NewAdsDnsResolver(NewAdsCache(nil)) + if err != nil { + t.Fatal(err) + } + stopCh := make(chan struct{}) + defer close(stopCh) + // testDNSResolver.StartAdsDnsResolver(stopCh) + dnsServer := fakeDNSServer.Server.PacketConn.LocalAddr().String() + testDNSResolver.dnsResolver.ResolvConfServers = []string{dnsServer} + + testCases := []struct { + name string + domain string + refreshRate time.Duration + ttl time.Duration + expected []string + expectedAfterTTL []string + registerDomain func(domain string) + }{ + { + name: "success", + domain: "www.google.com.", + refreshRate: 10 * time.Second, + expected: []string{"10.0.0.1", "fd00::1"}, + registerDomain: func(domain string) { + fakeDNSServer.SetHosts(domain, 1) + }, + }, + { + name: "check dns refresh after ttl, ttl < refreshRate", + domain: "www.bing.com.", + refreshRate: 10 * time.Second, + ttl: 3 * time.Second, + expected: []string{"10.0.0.2", "fd00::2"}, + expectedAfterTTL: []string{"10.0.0.3", "fd00::3"}, + registerDomain: func(domain string) { + fakeDNSServer.SetHosts(domain, 2) + fakeDNSServer.SetTTL(uint32(3)) + time.AfterFunc(time.Second, func() { + fakeDNSServer.SetHosts(domain, 3) + }) + }, + }, + { + name: "check dns refresh after ttl without update bpfmap", + domain: "www.test.com.", + refreshRate: 10 * time.Second, + ttl: 3 * time.Second, + expected: []string{"10.0.0.2", "fd00::2"}, + expectedAfterTTL: []string{"10.0.0.2", "fd00::2"}, + registerDomain: func(domain string) { + fakeDNSServer.SetHosts(domain, 2) + fakeDNSServer.SetTTL(uint32(3)) + }, + }, + { + name: "check dns refresh after refreshRate, ttl > refreshRate", + domain: "www.baidu.com.", + refreshRate: 3 * time.Second, + ttl: 10 * time.Second, + expected: []string{"10.0.0.2", "fd00::2"}, + expectedAfterTTL: []string{"10.0.0.3", "fd00::3"}, + registerDomain: func(domain string) { + fakeDNSServer.SetHosts(domain, 2) + fakeDNSServer.SetTTL(uint32(10)) + time.AfterFunc(time.Second, func() { + fakeDNSServer.SetHosts(domain, 3) + }) + }, + }, + { + 
name: "failed to resolve", + domain: "www.kmesh.test.", + refreshRate: 10 * time.Second, + expected: []string{}, + }, + } + var wg sync.WaitGroup + for _, testcase := range testCases { + wg.Add(1) + if testcase.registerDomain != nil { + testcase.registerDomain(testcase.domain) + } + + input := &dns.PendingResolveDomain{ + DomainName: testcase.domain, + RefreshRate: testcase.refreshRate, + } + testDNSResolver.dnsResolver.Lock() + testDNSResolver.dnsResolver.Cache[testcase.domain] = &dns.DomainCacheEntry{} + testDNSResolver.dnsResolver.Unlock() + go testDNSResolver.refreshAdsWorker() + + _, ttl, err := testDNSResolver.dnsResolver.Resolve(input.DomainName) + assert.NoError(t, err) + if ttl > input.RefreshRate { + ttl = input.RefreshRate + } + if ttl == 0 { + ttl = dns.DeRefreshInterval + } + testDNSResolver.dnsRefreshQueue.AddAfter(input, ttl) + time.Sleep(2 * time.Second) + + res := testDNSResolver.dnsResolver.GetDNSAddresses(testcase.domain) + if len(res) != 0 || len(testcase.expected) != 0 { + if !reflect.DeepEqual(res, testcase.expected) { + t.Errorf("dns resolve for %s do not match. \n got %v\nwant %v", testcase.domain, res, testcase.expected) + } + + if testcase.expectedAfterTTL != nil { + time.Sleep(ttl + 1) + res = testDNSResolver.dnsResolver.GetDNSAddresses(testcase.domain) + if !reflect.DeepEqual(res, testcase.expectedAfterTTL) { + t.Errorf("dns refresh after ttl failed, for %s do not match. 
\n got %v\nwant %v", testcase.domain, res, testcase.expectedAfterTTL) + } + } + } + wg.Done() + } + wg.Wait() +} + +func TestGetPendingResolveDomain(t *testing.T) { + utCluster := clusterv3.Cluster{ + Name: "testCluster", + LoadAssignment: &endpointv3.ClusterLoadAssignment{ + Endpoints: []*endpointv3.LocalityLbEndpoints{ + { + LbEndpoints: []*endpointv3.LbEndpoint{ + { + HostIdentifier: &endpointv3.LbEndpoint_Endpoint{ + Endpoint: &endpointv3.Endpoint{ + Address: &v3.Address{ + Address: &v3.Address_SocketAddress{ + SocketAddress: &v3.SocketAddress{ + Address: "192.168.2.1", + PortSpecifier: &v3.SocketAddress_PortValue{ + PortValue: uint32(9898), + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + } + + utClusterWithHost := clusterv3.Cluster{ + Name: "testCluster", + LoadAssignment: &endpointv3.ClusterLoadAssignment{ + Endpoints: []*endpointv3.LocalityLbEndpoints{ + { + LbEndpoints: []*endpointv3.LbEndpoint{ + { + HostIdentifier: &endpointv3.LbEndpoint_Endpoint{ + Endpoint: &endpointv3.Endpoint{ + Address: &v3.Address{ + Address: &v3.Address_SocketAddress{ + SocketAddress: &v3.SocketAddress{ + Address: "www.google.com", + PortSpecifier: &v3.SocketAddress_PortValue{ + PortValue: uint32(9898), + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + } + + type args struct { + clusters []*clusterv3.Cluster + } + tests := []struct { + name string + args args + want map[string]*dns.PendingResolveDomain + }{ + { + name: "empty domains test", + args: args{ + clusters: []*clusterv3.Cluster{ + &utCluster, + }, + }, + want: map[string]*dns.PendingResolveDomain{}, + }, + { + name: "cluster domain is not IP", + args: args{ + clusters: []*clusterv3.Cluster{ + &utClusterWithHost, + }, + }, + want: map[string]*dns.PendingResolveDomain{ + "www.google.com": { + DomainName: "www.google.com", + Clusters: []*clusterv3.Cluster{&utClusterWithHost}, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := 
getPendingResolveDomain(tt.args.clusters); !reflect.DeepEqual(got, tt.want) { + t.Errorf("getPendingResolveDomain() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/pkg/controller/controller.go b/pkg/controller/controller.go index 2d1805a7f..a6ac6cce9 100644 --- a/pkg/controller/controller.go +++ b/pkg/controller/controller.go @@ -27,11 +27,11 @@ import ( bpfads "kmesh.net/kmesh/pkg/bpf/ads" bpfwl "kmesh.net/kmesh/pkg/bpf/workload" "kmesh.net/kmesh/pkg/constants" + "kmesh.net/kmesh/pkg/controller/ads" "kmesh.net/kmesh/pkg/controller/bypass" "kmesh.net/kmesh/pkg/controller/encryption/ipsec" manage "kmesh.net/kmesh/pkg/controller/manage" "kmesh.net/kmesh/pkg/controller/security" - "kmesh.net/kmesh/pkg/dns" "kmesh.net/kmesh/pkg/kolog" "kmesh.net/kmesh/pkg/kube" "kmesh.net/kmesh/pkg/logger" @@ -156,7 +156,7 @@ func (c *Controller) Start(stopCh <-chan struct{}) error { } if c.client.AdsController != nil { - dnsResolver, err := dns.NewAdsDnsResolver(c.client.AdsController.Processor.Cache) + dnsResolver, err := ads.NewAdsDnsResolver(c.client.AdsController.Processor.Cache) if err != nil { return fmt.Errorf("dns resolver of Kernel-Native mode create failed: %v", err) } diff --git a/pkg/dns/ads_handler_test.go b/pkg/dns/ads_handler_test.go deleted file mode 100644 index 9fc35283e..000000000 --- a/pkg/dns/ads_handler_test.go +++ /dev/null @@ -1,225 +0,0 @@ -/* - * Copyright The Kmesh Authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at: - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package dns - -import ( - "math/rand" - "sync" - "testing" - "time" - - clusterv3 "github.com/envoyproxy/go-control-plane/envoy/config/cluster/v3" - v3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" - endpointv3 "github.com/envoyproxy/go-control-plane/envoy/config/endpoint/v3" - "github.com/stretchr/testify/assert" - "google.golang.org/protobuf/types/known/wrapperspb" - "istio.io/istio/pkg/slices" - "istio.io/istio/pkg/test/util/retry" - - core_v2 "kmesh.net/kmesh/api/v2/core" - "kmesh.net/kmesh/pkg/controller/ads" -) - -type fakeAdsDnsServer struct { -} - -func TestOverwriteDNSCluster(t *testing.T) { - domain := "www.google.com" - addrs := []string{"10.1.1.1", "10.1.1.2"} - cluster := &clusterv3.Cluster{ - Name: "ut-cluster", - ClusterDiscoveryType: &clusterv3.Cluster_Type{ - Type: clusterv3.Cluster_LOGICAL_DNS, - }, - LoadAssignment: &endpointv3.ClusterLoadAssignment{ - ClusterName: "ut-cluster", - Endpoints: []*endpointv3.LocalityLbEndpoints{ - { - LoadBalancingWeight: wrapperspb.UInt32(30), - Priority: uint32(15), - LbEndpoints: []*endpointv3.LbEndpoint{ - { - HealthStatus: v3.HealthStatus_HEALTHY, - HostIdentifier: &endpointv3.LbEndpoint_Endpoint{ - Endpoint: &endpointv3.Endpoint{ - Address: &v3.Address{ - Address: &v3.Address_SocketAddress{ - SocketAddress: &v3.SocketAddress{ - Address: domain, - PortSpecifier: &v3.SocketAddress_PortValue{ - PortValue: uint32(9898), - }, - }, - }, - }, - }, - }, - }, - }, - }, - }, - }, - } - - overwriteDnsCluster(cluster, domain, addrs) - - endpoints := cluster.GetLoadAssignment().GetEndpoints()[0].GetLbEndpoints() - if len(endpoints) != 2 { - t.Errorf("Expected 2 LbEndpoints, but got %d", len(endpoints)) - } - out := []string{} - for _, e := range endpoints { - socketAddr, ok := e.GetEndpoint().GetAddress().GetAddress().(*v3.Address_SocketAddress) - if !ok { - continue - } - address := socketAddr.SocketAddress.Address - out = append(out, address) - } - if !slices.Equal(out, addrs) { - 
t.Errorf("OverwriteDNSCluster error, expected %v, but got %v", out, addrs) - } -} - -// This test aims to evaluate the concurrent writing behavior of the adsCache by utilizing the test race feature. -// The test verifies the ability of the adsCache to handle concurrent access and updates correctly in a multi-goroutine environment. -func TestADSCacheConcurrentWriting(t *testing.T) { - adsCache := ads.NewAdsCache(nil) - cluster := &clusterv3.Cluster{ - Name: "ut-cluster", - ClusterDiscoveryType: &clusterv3.Cluster_Type{ - Type: clusterv3.Cluster_LOGICAL_DNS, - }, - } - adsCache.CreateApiClusterByCds(core_v2.ApiStatus_NONE, cluster) - - var wg sync.WaitGroup - for i := 0; i < 100; i++ { - wg.Add(1) - go func() { - defer wg.Done() - for j := 0; j < 100; j++ { - currentStatus := adsCache.GetApiClusterStatus(cluster.GetName()) - newStatus := currentStatus + core_v2.ApiStatus(rand.Intn(3)-1) - if rand.Intn(2) == 0 { - adsCache.UpdateApiClusterIfExists(newStatus, cluster) - } else { - adsCache.UpdateApiClusterStatus(cluster.GetName(), newStatus) - } - } - }() - } - - wg.Wait() -} - -func TestHandleCdsResponseWithDns(t *testing.T) { - cluster1 := &clusterv3.Cluster{ - Name: "ut-cluster1", - ClusterDiscoveryType: &clusterv3.Cluster_Type{ - Type: clusterv3.Cluster_LOGICAL_DNS, - }, - LoadAssignment: &endpointv3.ClusterLoadAssignment{ - Endpoints: []*endpointv3.LocalityLbEndpoints{ - { - LbEndpoints: []*endpointv3.LbEndpoint{ - { - HostIdentifier: &endpointv3.LbEndpoint_Endpoint{ - Endpoint: &endpointv3.Endpoint{ - Address: &v3.Address{ - Address: &v3.Address_SocketAddress{ - SocketAddress: &v3.SocketAddress{ - Address: "foo.bar", - PortSpecifier: &v3.SocketAddress_PortValue{ - PortValue: uint32(9898), - }, - }, - }, - }, - }, - }, - }, - }, - }, - }, - }, - } - cluster2 := &clusterv3.Cluster{ - Name: "ut-cluster2", - ClusterDiscoveryType: &clusterv3.Cluster_Type{ - Type: clusterv3.Cluster_STRICT_DNS, - }, - LoadAssignment: &endpointv3.ClusterLoadAssignment{ - Endpoints: 
[]*endpointv3.LocalityLbEndpoints{ - { - LbEndpoints: []*endpointv3.LbEndpoint{ - { - HostIdentifier: &endpointv3.LbEndpoint_Endpoint{ - Endpoint: &endpointv3.Endpoint{ - Address: &v3.Address{ - Address: &v3.Address_SocketAddress{ - SocketAddress: &v3.SocketAddress{ - Address: "foo.baz", - PortSpecifier: &v3.SocketAddress_PortValue{ - PortValue: uint32(9898), - }, - }, - }, - }, - }, - }, - }, - }, - }, - }, - }, - } - - testcases := []struct { - name string - clusters []*clusterv3.Cluster - expected []string - }{ - { - name: "add clusters with DNS type", - clusters: []*clusterv3.Cluster{cluster1, cluster2}, - expected: []string{"foo.bar", "foo.baz"}, - }, - { - name: "remove all DNS type clusters", - clusters: []*clusterv3.Cluster{}, - expected: []string{}, - }, - } - - p := ads.NewController(nil).Processor - stopCh := make(chan struct{}) - defer close(stopCh) - dnsResolver, err := NewAdsDnsResolver(p.Cache) - assert.NoError(t, err) - dnsResolver.StartAdsDnsResolver(stopCh) - p.DnsResolverChan = dnsResolver.Clusters - for _, tc := range testcases { - t.Run(tc.name, func(t *testing.T) { - // notify dns resolver - dnsResolver.Clusters <- tc.clusters - retry.UntilOrFail(t, func() bool { - return slices.EqualUnordered(tc.expected, dnsResolver.DnsResolver.GetAllCachedDomains()) - }, retry.Timeout(1*time.Second)) - }) - } -} diff --git a/pkg/dns/dns.go b/pkg/dns/dns.go index b94a2445c..9eb6bfa3b 100644 --- a/pkg/dns/dns.go +++ b/pkg/dns/dns.go @@ -24,9 +24,7 @@ import ( "time" clusterv3 "github.com/envoyproxy/go-control-plane/envoy/config/cluster/v3" - "github.com/miekg/dns" - "k8s.io/client-go/util/workqueue" "kmesh.net/kmesh/pkg/logger" ) @@ -43,30 +41,27 @@ const ( type DNSResolver struct { client *dns.Client - resolvConfServers []string - cache map[string]*domainCacheEntry - // dns refresh priority queue based on exp - dnsRefreshQueue workqueue.TypedDelayingInterface[any] + ResolvConfServers []string + Cache map[string]*DomainCacheEntry sync.RWMutex } -type 
domainCacheEntry struct { +type DomainCacheEntry struct { addresses []string } -// pending resolve domain info, +// pending resolve domain info of Kennel-Native Mode, // domain name is used for dns resolution // cluster is used for create the apicluster -type pendingResolveDomain struct { - domainName string - clusters []*clusterv3.Cluster - refreshRate time.Duration +type PendingResolveDomain struct { + DomainName string + Clusters []*clusterv3.Cluster + RefreshRate time.Duration } func NewDNSResolver() (*DNSResolver, error) { r := &DNSResolver{ - cache: map[string]*domainCacheEntry{}, - dnsRefreshQueue: workqueue.NewTypedDelayingQueueWithConfig(workqueue.TypedDelayingQueueConfig[any]{Name: "dnsRefreshQueue"}), + Cache: map[string]*DomainCacheEntry{}, client: &dns.Client{ DialTimeout: 5 * time.Second, ReadTimeout: 5 * time.Second, @@ -80,7 +75,7 @@ func NewDNSResolver() (*DNSResolver, error) { } if dnsConfig != nil { for _, s := range dnsConfig.Servers { - r.resolvConfServers = append(r.resolvConfServers, net.JoinHostPort(s, dnsConfig.Port)) + r.ResolvConfServers = append(r.ResolvConfServers, net.JoinHostPort(s, dnsConfig.Port)) } } @@ -88,69 +83,53 @@ func NewDNSResolver() (*DNSResolver, error) { } // removeUnwatchedDomain cancels any scheduled re-resolve for names we no longer care about -func (r *DNSResolver) removeUnwatchedDomain(domains map[string]*pendingResolveDomain) { +func (r *DNSResolver) RemoveUnwatchedDomain(domains map[string]*PendingResolveDomain) { r.Lock() defer r.Unlock() - for domain := range r.cache { + for domain := range r.Cache { if _, ok := domains[domain]; ok { continue } - delete(r.cache, domain) + delete(r.Cache, domain) } } // This functions were copied and adapted from github.com/istio/istio/pilot/pkg/model/network.go. 
-func (r *DNSResolver) resolve(v *pendingResolveDomain) ([]string, error) { +func (r *DNSResolver) Resolve(domainName string) ([]string, time.Duration, error) { r.RLock() - entry := r.cache[v.domainName] + entry := r.Cache[domainName] // This can happen when the domain is deleted before the refresher tick reaches if entry == nil { r.RUnlock() - return []string{}, fmt.Errorf("cache entry for domain %s not found", v.domainName) + return []string{}, DeRefreshInterval, fmt.Errorf("cache entry for domain %s not found", domainName) } r.RUnlock() - addrs, ttl, err := r.doResolve(v.domainName, v.refreshRate) + addrs, ttl, err := r.doResolve(domainName) if err != nil { - return []string{}, fmt.Errorf("dns resolve failed: %v", err) - } - - if ttl > v.refreshRate { - ttl = v.refreshRate - } - if ttl == 0 { - ttl = DeRefreshInterval + return []string{}, DeRefreshInterval, fmt.Errorf("dns resolve failed: %v", err) } r.RLock() entry.addresses = addrs r.RUnlock() - // push to refresh queue - r.dnsRefreshQueue.AddAfter(v, ttl) - return addrs, nil + return addrs, ttl, nil } // resolveDomains takes a slice of cluster -func (r *DNSResolver) resolveDomains(clusters []*clusterv3.Cluster) { - domains := getPendingResolveDomain(clusters) - - // Stow domain updates, need to remove unwatched domains first - r.removeUnwatchedDomain(domains) - for _, v := range domains { - r.Lock() - if r.cache[v.domainName] == nil { - r.cache[v.domainName] = &domainCacheEntry{} - } - r.Unlock() - r.dnsRefreshQueue.AddAfter(v, 0) +func (r *DNSResolver) ResolveDomains(domainName string) { + r.Lock() + if r.Cache[domainName] == nil { + r.Cache[domainName] = &DomainCacheEntry{} } + r.Unlock() } // doResolve is copied and adapted from github.com/istio/istio/pilot/pkg/model/network.go. 
-func (r *DNSResolver) doResolve(domain string, refreshRate time.Duration) ([]string, time.Duration, error) { +func (r *DNSResolver) doResolve(domain string) ([]string, time.Duration, error) { var out []string - ttl := refreshRate + ttl := DeRefreshInterval var mu sync.Mutex var wg sync.WaitGroup var errs = []error{} @@ -174,7 +153,7 @@ func (r *DNSResolver) doResolve(domain string, refreshRate time.Duration) ([]str out = append(out, record.AAAA.String()) } } - if minTTL := getMinTTL(res, refreshRate); minTTL < ttl { + if minTTL := getMinTTL(res, DeRefreshInterval); minTTL < ttl { ttl = minTTL } } @@ -186,7 +165,7 @@ func (r *DNSResolver) doResolve(domain string, refreshRate time.Duration) ([]str if len(errs) == 2 { // return error only if all requests are failed - return out, refreshRate, fmt.Errorf("upstream dns failure") + return out, DeRefreshInterval, fmt.Errorf("upstream dns failure") } sort.Strings(out) @@ -196,7 +175,7 @@ func (r *DNSResolver) doResolve(domain string, refreshRate time.Duration) ([]str // Query is copied and adapted from github.com/istio/istio/pilot/pkg/model/network.go. func (r *DNSResolver) Query(req *dns.Msg) *dns.Msg { var response *dns.Msg - for _, upstream := range r.resolvConfServers { + for _, upstream := range r.ResolvConfServers { resp, _, err := r.client.Exchange(req, upstream) if err != nil || resp == nil { continue @@ -246,3 +225,13 @@ func getMinTTL(m *dns.Msg, refreshRate time.Duration) time.Duration { } return minTTL } + +func (r *DNSResolver) GetAllCachedDomains() []string { + r.RLock() + defer r.RUnlock() + out := make([]string, 0, len(r.Cache)) + for domain := range r.Cache { + out = append(out, domain) + } + return out +} diff --git a/pkg/dns/dns_test.go b/pkg/dns/dns_test.go deleted file mode 100644 index 0cef24d64..000000000 --- a/pkg/dns/dns_test.go +++ /dev/null @@ -1,351 +0,0 @@ -/* - * Copyright The Kmesh Authors. 
- * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at: - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package dns - -import ( - "fmt" - "math" - "net" - "reflect" - "sync" - "testing" - "time" - - clusterv3 "github.com/envoyproxy/go-control-plane/envoy/config/cluster/v3" - v3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" - endpointv3 "github.com/envoyproxy/go-control-plane/envoy/config/endpoint/v3" - "github.com/miekg/dns" - - "kmesh.net/kmesh/pkg/controller/ads" -) - -type fakeDNSServer struct { - *dns.Server - ttl uint32 - failure bool - - mu sync.Mutex - // map fqdn hostname -> ip suffix - hosts map[string]int -} - -func TestDNS(t *testing.T) { - fakeDNSServer := newFakeDNSServer() - - // testDNSResolver, err := NewDNSResolver(ads.NewAdsCache(nil)) - testDNSResolver, err := NewAdsDnsResolver(ads.NewAdsCache(nil)) - if err != nil { - t.Fatal(err) - } - stopCh := make(chan struct{}) - defer close(stopCh) - testDNSResolver.StartAdsDnsResolver(stopCh) - dnsServer := fakeDNSServer.Server.PacketConn.LocalAddr().String() - testDNSResolver.DnsResolver.resolvConfServers = []string{dnsServer} - - testCases := []struct { - name string - domain string - refreshRate time.Duration - ttl time.Duration - expected []string - expectedAfterTTL []string - registerDomain func(domain string) - }{ - { - name: "success", - domain: "www.google.com.", - refreshRate: 10 * time.Second, - expected: []string{"10.0.0.1", "fd00::1"}, - registerDomain: func(domain string) { - fakeDNSServer.setHosts(domain, 1) - }, 
- }, - { - name: "check dns refresh after ttl, ttl < refreshRate", - domain: "www.bing.com.", - refreshRate: 10 * time.Second, - ttl: 3 * time.Second, - expected: []string{"10.0.0.2", "fd00::2"}, - expectedAfterTTL: []string{"10.0.0.3", "fd00::3"}, - registerDomain: func(domain string) { - fakeDNSServer.setHosts(domain, 2) - fakeDNSServer.setTTL(uint32(3)) - time.AfterFunc(time.Second, func() { - fakeDNSServer.setHosts(domain, 3) - }) - }, - }, - { - name: "check dns refresh after ttl without update bpfmap", - domain: "www.test.com.", - refreshRate: 10 * time.Second, - ttl: 3 * time.Second, - expected: []string{"10.0.0.2", "fd00::2"}, - expectedAfterTTL: []string{"10.0.0.2", "fd00::2"}, - registerDomain: func(domain string) { - fakeDNSServer.setHosts(domain, 2) - fakeDNSServer.setTTL(uint32(3)) - }, - }, - { - name: "check dns refresh after refreshRate, ttl > refreshRate", - domain: "www.baidu.com.", - refreshRate: 3 * time.Second, - ttl: 10 * time.Second, - expected: []string{"10.0.0.2", "fd00::2"}, - expectedAfterTTL: []string{"10.0.0.3", "fd00::3"}, - registerDomain: func(domain string) { - fakeDNSServer.setHosts(domain, 2) - fakeDNSServer.setTTL(uint32(10)) - time.AfterFunc(time.Second, func() { - fakeDNSServer.setHosts(domain, 3) - }) - }, - }, - { - name: "failed to resolve", - domain: "www.kmesh.test.", - refreshRate: 10 * time.Second, - expected: []string{}, - }, - } - var wg sync.WaitGroup - for _, testcase := range testCases { - wg.Add(1) - if testcase.registerDomain != nil { - testcase.registerDomain(testcase.domain) - } - - input := &pendingResolveDomain{ - domainName: testcase.domain, - refreshRate: testcase.refreshRate, - } - testDNSResolver.DnsResolver.Lock() - testDNSResolver.DnsResolver.cache[testcase.domain] = &domainCacheEntry{} - testDNSResolver.DnsResolver.Unlock() - - testDNSResolver.DnsResolver.resolve(input) - - time.Sleep(2 * time.Second) - - res := testDNSResolver.DnsResolver.GetDNSAddresses(testcase.domain) - if len(res) != 0 || 
len(testcase.expected) != 0 { - if !reflect.DeepEqual(res, testcase.expected) { - t.Errorf("dns resolve for %s do not match. \n got %v\nwant %v", testcase.domain, res, testcase.expected) - } - - if testcase.expectedAfterTTL != nil { - ttl := time.Duration(math.Min(float64(testcase.ttl), float64(testcase.refreshRate))) - time.Sleep(ttl + 1) - res = testDNSResolver.DnsResolver.GetDNSAddresses(testcase.domain) - if !reflect.DeepEqual(res, testcase.expectedAfterTTL) { - t.Errorf("dns refresh after ttl failed, for %s do not match. \n got %v\nwant %v", testcase.domain, res, testcase.expectedAfterTTL) - } - } - } - wg.Done() - } - wg.Wait() -} - -func newFakeDNSServer() *fakeDNSServer { - var wg sync.WaitGroup - wg.Add(1) - s := &fakeDNSServer{ - Server: &dns.Server{Addr: ":0", Net: "udp", NotifyStartedFunc: wg.Done}, - hosts: make(map[string]int), - // default ttl is 20 - ttl: uint32(20), - } - s.Handler = s - - go func() { - if err := s.ListenAndServe(); err != nil { - log.Errorf("fake dns server error: %v", err) - } - }() - wg.Wait() - return s -} - -func (s *fakeDNSServer) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { - s.mu.Lock() - defer s.mu.Unlock() - - msg := (&dns.Msg{}).SetReply(r) - if s.failure { - msg.Rcode = dns.RcodeServerFailure - } else { - domain := msg.Question[0].Name - c, ok := s.hosts[domain] - if ok { - switch r.Question[0].Qtype { - case dns.TypeA: - msg.Answer = append(msg.Answer, &dns.A{ - Hdr: dns.RR_Header{Name: domain, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: s.ttl}, - A: net.ParseIP(fmt.Sprintf("10.0.0.%d", c)), - }) - case dns.TypeAAAA: - // set a long TTL for AAAA - msg.Answer = append(msg.Answer, &dns.AAAA{ - Hdr: dns.RR_Header{Name: domain, Rrtype: dns.TypeAAAA, Class: dns.ClassINET, Ttl: s.ttl * 10}, - AAAA: net.ParseIP(fmt.Sprintf("fd00::%x", c)), - }) - // simulate behavior of some public/cloud DNS like Cloudflare or DigitalOcean - case dns.TypeANY: - msg.Rcode = dns.RcodeRefused - default: - msg.Rcode = dns.RcodeNotImplemented - 
} - } else { - msg.Rcode = dns.RcodeNameError - } - } - if err := w.WriteMsg(msg); err != nil { - log.Errorf("failed writing fake DNS response: %v", err) - } -} - -func (s *fakeDNSServer) setHosts(domain string, surfix int) { - s.mu.Lock() - defer s.mu.Unlock() - s.hosts[dns.Fqdn(domain)] = surfix -} - -func (s *fakeDNSServer) setTTL(ttl uint32) { - s.mu.Lock() - defer s.mu.Unlock() - s.ttl = ttl -} - -func (r *DNSResolver) GetAllCachedDomains() []string { - r.RLock() - defer r.RUnlock() - out := make([]string, 0, len(r.cache)) - for domain := range r.cache { - out = append(out, domain) - } - return out -} - -func (r *DNSResolver) GetDNSAddresses(domain string) []string { - r.Lock() - defer r.Unlock() - if entry, ok := r.cache[domain]; ok { - return entry.addresses - } - return nil -} - -func TestGetPendingResolveDomain(t *testing.T) { - utCluster := clusterv3.Cluster{ - Name: "testCluster", - LoadAssignment: &endpointv3.ClusterLoadAssignment{ - Endpoints: []*endpointv3.LocalityLbEndpoints{ - { - LbEndpoints: []*endpointv3.LbEndpoint{ - { - HostIdentifier: &endpointv3.LbEndpoint_Endpoint{ - Endpoint: &endpointv3.Endpoint{ - Address: &v3.Address{ - Address: &v3.Address_SocketAddress{ - SocketAddress: &v3.SocketAddress{ - Address: "192.168.2.1", - PortSpecifier: &v3.SocketAddress_PortValue{ - PortValue: uint32(9898), - }, - }, - }, - }, - }, - }, - }, - }, - }, - }, - }, - } - - utClusterWithHost := clusterv3.Cluster{ - Name: "testCluster", - LoadAssignment: &endpointv3.ClusterLoadAssignment{ - Endpoints: []*endpointv3.LocalityLbEndpoints{ - { - LbEndpoints: []*endpointv3.LbEndpoint{ - { - HostIdentifier: &endpointv3.LbEndpoint_Endpoint{ - Endpoint: &endpointv3.Endpoint{ - Address: &v3.Address{ - Address: &v3.Address_SocketAddress{ - SocketAddress: &v3.SocketAddress{ - Address: "www.google.com", - PortSpecifier: &v3.SocketAddress_PortValue{ - PortValue: uint32(9898), - }, - }, - }, - }, - }, - }, - }, - }, - }, - }, - }, - } - - type args struct { - clusters 
[]*clusterv3.Cluster - } - tests := []struct { - name string - args args - want map[string]*pendingResolveDomain - }{ - { - name: "empty domains test", - args: args{ - clusters: []*clusterv3.Cluster{ - &utCluster, - }, - }, - want: map[string]*pendingResolveDomain{}, - }, - { - name: "cluster domain is not IP", - args: args{ - clusters: []*clusterv3.Cluster{ - &utClusterWithHost, - }, - }, - want: map[string]*pendingResolveDomain{ - "www.google.com": { - domainName: "www.google.com", - clusters: []*clusterv3.Cluster{&utClusterWithHost}, - }, - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := getPendingResolveDomain(tt.args.clusters); !reflect.DeepEqual(got, tt.want) { - t.Errorf("getPendingResolveDomain() = %v, want %v", got, tt.want) - } - }) - } -} diff --git a/pkg/dns/utils.go b/pkg/dns/utils.go new file mode 100644 index 000000000..5c2ca1ff7 --- /dev/null +++ b/pkg/dns/utils.go @@ -0,0 +1,114 @@ +/* + * Copyright The Kmesh Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package dns + +import ( + "fmt" + "net" + "sync" + + "github.com/miekg/dns" +) + +type fakeDNSServer struct { + *dns.Server + ttl uint32 + failure bool + + mu sync.Mutex + // map fqdn hostname -> ip suffix + hosts map[string]int +} + +func NewFakeDNSServer() *fakeDNSServer { + var wg sync.WaitGroup + wg.Add(1) + s := &fakeDNSServer{ + Server: &dns.Server{Addr: ":0", Net: "udp", NotifyStartedFunc: wg.Done}, + hosts: make(map[string]int), + // default ttl is 20 + ttl: uint32(20), + } + s.Handler = s + + go func() { + if err := s.ListenAndServe(); err != nil { + log.Errorf("fake dns server error: %v", err) + } + }() + wg.Wait() + return s +} + +func (s *fakeDNSServer) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { + s.mu.Lock() + defer s.mu.Unlock() + + msg := (&dns.Msg{}).SetReply(r) + if s.failure { + msg.Rcode = dns.RcodeServerFailure + } else { + domain := msg.Question[0].Name + c, ok := s.hosts[domain] + if ok { + switch r.Question[0].Qtype { + case dns.TypeA: + msg.Answer = append(msg.Answer, &dns.A{ + Hdr: dns.RR_Header{Name: domain, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: s.ttl}, + A: net.ParseIP(fmt.Sprintf("10.0.0.%d", c)), + }) + case dns.TypeAAAA: + // set a long TTL for AAAA + msg.Answer = append(msg.Answer, &dns.AAAA{ + Hdr: dns.RR_Header{Name: domain, Rrtype: dns.TypeAAAA, Class: dns.ClassINET, Ttl: s.ttl * 10}, + AAAA: net.ParseIP(fmt.Sprintf("fd00::%x", c)), + }) + // simulate behavior of some public/cloud DNS like Cloudflare or DigitalOcean + case dns.TypeANY: + msg.Rcode = dns.RcodeRefused + default: + msg.Rcode = dns.RcodeNotImplemented + } + } else { + msg.Rcode = dns.RcodeNameError + } + } + if err := w.WriteMsg(msg); err != nil { + log.Errorf("failed writing fake DNS response: %v", err) + } +} + +func (s *fakeDNSServer) SetHosts(domain string, surfix int) { + s.mu.Lock() + defer s.mu.Unlock() + s.hosts[dns.Fqdn(domain)] = surfix +} + +func (s *fakeDNSServer) SetTTL(ttl uint32) { + s.mu.Lock() + defer s.mu.Unlock() + s.ttl = ttl +} + 
+func (r *DNSResolver) GetDNSAddresses(domain string) []string { + r.Lock() + defer r.Unlock() + if entry, ok := r.Cache[domain]; ok { + return entry.addresses + } + return nil +} From b9fca75bb002ef256c8957f77c5cceb352fff327 Mon Sep 17 00:00:00 2001 From: LiZhenCheng9527 Date: Thu, 27 Feb 2025 17:13:15 +0800 Subject: [PATCH 05/11] address comments Signed-off-by: LiZhenCheng9527 --- pkg/controller/ads/dns.go | 132 +++++++++++++++--------------- pkg/controller/ads/dns_test.go | 142 ++------------------------------ pkg/controller/controller.go | 4 +- pkg/dns/dns.go | 112 +++++++++++++++++++------ pkg/dns/dns_test.go | 144 +++++++++++++++++++++++++++++++++ pkg/dns/utils.go | 4 +- 6 files changed, 308 insertions(+), 230 deletions(-) create mode 100644 pkg/dns/dns_test.go diff --git a/pkg/controller/ads/dns.go b/pkg/controller/ads/dns.go index 46b6cac0f..b5e049591 100644 --- a/pkg/controller/ads/dns.go +++ b/pkg/controller/ads/dns.go @@ -20,115 +20,119 @@ import ( "net" "net/netip" "slices" + "time" clusterv3 "github.com/envoyproxy/go-control-plane/envoy/config/cluster/v3" v3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" endpointv3 "github.com/envoyproxy/go-control-plane/envoy/config/endpoint/v3" "google.golang.org/protobuf/types/known/wrapperspb" - "k8s.io/client-go/util/workqueue" core_v2 "kmesh.net/kmesh/api/v2/core" "kmesh.net/kmesh/pkg/dns" ) // adsDnsResolver is DNS resolver of Kernel Native -type AdsDnsResolver struct { - Clusters chan []*clusterv3.Cluster - adsCache *AdsCache - dnsResolver *dns.DNSResolver - dnsRefreshQueue workqueue.TypedDelayingInterface[any] +type dnsController struct { + Clusters chan []*clusterv3.Cluster + cache *AdsCache + dnsResolver *dns.DNSResolver } -func NewAdsDnsResolver(adsCache *AdsCache) (*AdsDnsResolver, error) { +// pending resolve domain info of Kennel-Native Mode, +// domain name is used for dns resolution +// cluster is used for create the apicluster +type pendingResolveDomain struct { + DomainName string + 
Clusters []*clusterv3.Cluster + RefreshRate time.Duration +} + +func NewDnsResolver(adsCache *AdsCache) (*dnsController, error) { resolver, err := dns.NewDNSResolver() if err != nil { return nil, err } - return &AdsDnsResolver{ - Clusters: make(chan []*clusterv3.Cluster), - dnsRefreshQueue: workqueue.NewTypedDelayingQueueWithConfig(workqueue.TypedDelayingQueueConfig[any]{Name: "dnsRefreshQueue"}), - adsCache: adsCache, - dnsResolver: resolver, + return &dnsController{ + Clusters: make(chan []*clusterv3.Cluster), + // dnsRefreshQueue: workqueue.NewTypedDelayingQueueWithConfig(workqueue.TypedDelayingQueueConfig[any]{Name: "dnsRefreshQueue"}), + cache: adsCache, + dnsResolver: resolver, }, nil } -func (adsResolver *AdsDnsResolver) StartAdsDnsResolver(stopCh <-chan struct{}) { - go adsResolver.startAdsResolver() - go adsResolver.refreshAdsWorker() +func (r *dnsController) StartKernelNativeDnsController(stopCh <-chan struct{}) { + go r.startDnsController() + // start dns resolver + go r.dnsResolver.StartDnsResolver(stopCh) go func() { <-stopCh - adsResolver.dnsRefreshQueue.ShutDown() - close(adsResolver.Clusters) + close(r.Clusters) }() } -func (adsResolver *AdsDnsResolver) startAdsResolver() { +func (r *dnsController) startDnsController() { rateLimiter := make(chan struct{}, dns.MaxConcurrency) - for clusters := range adsResolver.Clusters { + for clusters := range r.Clusters { rateLimiter <- struct{}{} go func(clusters []*clusterv3.Cluster) { defer func() { <-rateLimiter }() - adsResolver.resolveDomains(clusters) + r.resolveDomains(clusters) }(clusters) } } -func (adsResolver *AdsDnsResolver) refreshAdsDns() bool { - element, quit := adsResolver.dnsRefreshQueue.Get() - if quit { - return false - } - defer adsResolver.dnsRefreshQueue.Done(element) - e := element.(*dns.PendingResolveDomain) - - adsResolver.dnsResolver.RLock() - _, exist := adsResolver.dnsResolver.Cache[e.DomainName] - adsResolver.dnsResolver.RUnlock() - // if the domain is no longer watched, no need to 
refresh it - if !exist { - return true - } - addresses, ttl, err := adsResolver.dnsResolver.Resolve(e.DomainName) - if err != nil { - log.Errorf("failed to dns resolve: %v", err) - return false - } - if ttl > e.RefreshRate { - ttl = e.RefreshRate - } - if ttl == 0 { - ttl = dns.DeRefreshInterval - } - adsResolver.dnsRefreshQueue.AddAfter(e, ttl) +func (r *dnsController) resolveDomains(cds []*clusterv3.Cluster) { + domains := getPendingResolveDomain(cds) + hostNames := make(map[string]struct{}) - adsResolver.adsDnsResolve(e, addresses) - adsResolver.adsCache.ClusterCache.Flush() - return true -} + for k := range domains { + hostNames[k] = struct{}{} + } -func (adsResolver *AdsDnsResolver) resolveDomains(cds []*clusterv3.Cluster) { - domains := getPendingResolveDomain(cds) + // delete any scheduled re-resolve for domains we no longer care about + r.dnsResolver.RemoveUnwatchDomain(hostNames) + // Directly update the clusters that can find the dns resolution result in the cache + alreadyResolveDomains := r.dnsResolver.GetAddressesFromCache(hostNames) + for k, v := range alreadyResolveDomains { + pendingDomain := domains[k] + r.adsDnsResolve(pendingDomain, v.Addresses) + r.cache.ClusterCache.Flush() + delete(domains, k) + } - // Stow domain updates, need to remove unwatched domains first - adsResolver.dnsResolver.RemoveUnwatchedDomain(domains) for k, v := range domains { - adsResolver.dnsResolver.ResolveDomains(k) - adsResolver.dnsRefreshQueue.AddAfter(v, 0) + r.dnsResolver.ResolveDomains(k) + domainInfo := &dns.DomainInfo{ + Domain: v.DomainName, + RefreshRate: v.RefreshRate, + } + r.dnsResolver.RefreshQueue.AddAfter(domainInfo, 0) } + go r.refreshAdsWorker(domains) } -func (adsResolver *AdsDnsResolver) refreshAdsWorker() { - for adsResolver.refreshAdsDns() { +func (r *dnsController) refreshAdsWorker(domains map[string]*pendingResolveDomain) { + for !(len(domains) == 0) { + domain := <-r.dnsResolver.AdsDnsChan + v, ok := domains[domain] + // will this happen? 
+ if !ok { + continue + } + addresses, _ := r.dnsResolver.GetOneDomainFromCache(domain) + r.adsDnsResolve(v, addresses) + r.cache.ClusterCache.Flush() + delete(domains, domain) } } -func (adsResolver *AdsDnsResolver) adsDnsResolve(pendingDomain *dns.PendingResolveDomain, addrs []string) { +func (r *dnsController) adsDnsResolve(pendingDomain *pendingResolveDomain, addrs []string) { for _, cluster := range pendingDomain.Clusters { ready := overwriteDnsCluster(cluster, pendingDomain.DomainName, addrs) if ready { - if !adsResolver.adsCache.UpdateApiClusterIfExists(core_v2.ApiStatus_UPDATE, cluster) { + if !r.cache.UpdateApiClusterIfExists(core_v2.ApiStatus_UPDATE, cluster) { log.Debugf("cluster: %s is deleted", cluster.Name) return } @@ -201,8 +205,8 @@ func overwriteDnsCluster(cluster *clusterv3.Cluster, domain string, addrs []stri return ready } -func getPendingResolveDomain(cds []*clusterv3.Cluster) map[string]*dns.PendingResolveDomain { - domains := make(map[string]*dns.PendingResolveDomain) +func getPendingResolveDomain(cds []*clusterv3.Cluster) map[string]*pendingResolveDomain { + domains := make(map[string]*pendingResolveDomain) for _, cluster := range cds { if cluster.LoadAssignment == nil { @@ -224,7 +228,7 @@ func getPendingResolveDomain(cds []*clusterv3.Cluster) map[string]*dns.PendingRe if v, ok := domains[address]; ok { v.Clusters = append(v.Clusters, cluster) } else { - domainWithRefreshRate := &dns.PendingResolveDomain{ + domainWithRefreshRate := &pendingResolveDomain{ DomainName: address, Clusters: []*clusterv3.Cluster{cluster}, RefreshRate: cluster.GetDnsRefreshRate().AsDuration(), diff --git a/pkg/controller/ads/dns_test.go b/pkg/controller/ads/dns_test.go index e135e0a05..30a080cfe 100644 --- a/pkg/controller/ads/dns_test.go +++ b/pkg/controller/ads/dns_test.go @@ -32,7 +32,6 @@ import ( "istio.io/istio/pkg/test/util/retry" core_v2 "kmesh.net/kmesh/api/v2/core" - "kmesh.net/kmesh/pkg/dns" ) func TestOverwriteDNSCluster(t *testing.T) { @@ -197,19 
+196,14 @@ func TestHandleCdsResponseWithDns(t *testing.T) { clusters: []*clusterv3.Cluster{cluster1, cluster2}, expected: []string{"foo.bar", "foo.baz"}, }, - { - name: "remove all DNS type clusters", - clusters: []*clusterv3.Cluster{}, - expected: []string{}, - }, } p := NewController(nil).Processor stopCh := make(chan struct{}) defer close(stopCh) - dnsResolver, err := NewAdsDnsResolver(p.Cache) + dnsResolver, err := NewDnsResolver(p.Cache) assert.NoError(t, err) - dnsResolver.StartAdsDnsResolver(stopCh) + dnsResolver.StartKernelNativeDnsController(stopCh) p.DnsResolverChan = dnsResolver.Clusters for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { @@ -222,132 +216,6 @@ func TestHandleCdsResponseWithDns(t *testing.T) { } } -func TestDNS(t *testing.T) { - fakeDNSServer := dns.NewFakeDNSServer() - - testDNSResolver, err := NewAdsDnsResolver(NewAdsCache(nil)) - if err != nil { - t.Fatal(err) - } - stopCh := make(chan struct{}) - defer close(stopCh) - // testDNSResolver.StartAdsDnsResolver(stopCh) - dnsServer := fakeDNSServer.Server.PacketConn.LocalAddr().String() - testDNSResolver.dnsResolver.ResolvConfServers = []string{dnsServer} - - testCases := []struct { - name string - domain string - refreshRate time.Duration - ttl time.Duration - expected []string - expectedAfterTTL []string - registerDomain func(domain string) - }{ - { - name: "success", - domain: "www.google.com.", - refreshRate: 10 * time.Second, - expected: []string{"10.0.0.1", "fd00::1"}, - registerDomain: func(domain string) { - fakeDNSServer.SetHosts(domain, 1) - }, - }, - { - name: "check dns refresh after ttl, ttl < refreshRate", - domain: "www.bing.com.", - refreshRate: 10 * time.Second, - ttl: 3 * time.Second, - expected: []string{"10.0.0.2", "fd00::2"}, - expectedAfterTTL: []string{"10.0.0.3", "fd00::3"}, - registerDomain: func(domain string) { - fakeDNSServer.SetHosts(domain, 2) - fakeDNSServer.SetTTL(uint32(3)) - time.AfterFunc(time.Second, func() { - 
fakeDNSServer.SetHosts(domain, 3) - }) - }, - }, - { - name: "check dns refresh after ttl without update bpfmap", - domain: "www.test.com.", - refreshRate: 10 * time.Second, - ttl: 3 * time.Second, - expected: []string{"10.0.0.2", "fd00::2"}, - expectedAfterTTL: []string{"10.0.0.2", "fd00::2"}, - registerDomain: func(domain string) { - fakeDNSServer.SetHosts(domain, 2) - fakeDNSServer.SetTTL(uint32(3)) - }, - }, - { - name: "check dns refresh after refreshRate, ttl > refreshRate", - domain: "www.baidu.com.", - refreshRate: 3 * time.Second, - ttl: 10 * time.Second, - expected: []string{"10.0.0.2", "fd00::2"}, - expectedAfterTTL: []string{"10.0.0.3", "fd00::3"}, - registerDomain: func(domain string) { - fakeDNSServer.SetHosts(domain, 2) - fakeDNSServer.SetTTL(uint32(10)) - time.AfterFunc(time.Second, func() { - fakeDNSServer.SetHosts(domain, 3) - }) - }, - }, - { - name: "failed to resolve", - domain: "www.kmesh.test.", - refreshRate: 10 * time.Second, - expected: []string{}, - }, - } - var wg sync.WaitGroup - for _, testcase := range testCases { - wg.Add(1) - if testcase.registerDomain != nil { - testcase.registerDomain(testcase.domain) - } - - input := &dns.PendingResolveDomain{ - DomainName: testcase.domain, - RefreshRate: testcase.refreshRate, - } - testDNSResolver.dnsResolver.Lock() - testDNSResolver.dnsResolver.Cache[testcase.domain] = &dns.DomainCacheEntry{} - testDNSResolver.dnsResolver.Unlock() - go testDNSResolver.refreshAdsWorker() - - _, ttl, err := testDNSResolver.dnsResolver.Resolve(input.DomainName) - assert.NoError(t, err) - if ttl > input.RefreshRate { - ttl = input.RefreshRate - } - if ttl == 0 { - ttl = dns.DeRefreshInterval - } - testDNSResolver.dnsRefreshQueue.AddAfter(input, ttl) - time.Sleep(2 * time.Second) - - res := testDNSResolver.dnsResolver.GetDNSAddresses(testcase.domain) - if len(res) != 0 || len(testcase.expected) != 0 { - if !reflect.DeepEqual(res, testcase.expected) { - t.Errorf("dns resolve for %s do not match. 
\n got %v\nwant %v", testcase.domain, res, testcase.expected) - } - - if testcase.expectedAfterTTL != nil { - time.Sleep(ttl + 1) - res = testDNSResolver.dnsResolver.GetDNSAddresses(testcase.domain) - if !reflect.DeepEqual(res, testcase.expectedAfterTTL) { - t.Errorf("dns refresh after ttl failed, for %s do not match. \n got %v\nwant %v", testcase.domain, res, testcase.expectedAfterTTL) - } - } - } - wg.Done() - } - wg.Wait() -} - func TestGetPendingResolveDomain(t *testing.T) { utCluster := clusterv3.Cluster{ Name: "testCluster", @@ -411,7 +279,7 @@ func TestGetPendingResolveDomain(t *testing.T) { tests := []struct { name string args args - want map[string]*dns.PendingResolveDomain + want map[string]*pendingResolveDomain }{ { name: "empty domains test", @@ -420,7 +288,7 @@ func TestGetPendingResolveDomain(t *testing.T) { &utCluster, }, }, - want: map[string]*dns.PendingResolveDomain{}, + want: map[string]*pendingResolveDomain{}, }, { name: "cluster domain is not IP", @@ -429,7 +297,7 @@ func TestGetPendingResolveDomain(t *testing.T) { &utClusterWithHost, }, }, - want: map[string]*dns.PendingResolveDomain{ + want: map[string]*pendingResolveDomain{ "www.google.com": { DomainName: "www.google.com", Clusters: []*clusterv3.Cluster{&utClusterWithHost}, diff --git a/pkg/controller/controller.go b/pkg/controller/controller.go index a6ac6cce9..4845f58e3 100644 --- a/pkg/controller/controller.go +++ b/pkg/controller/controller.go @@ -156,11 +156,11 @@ func (c *Controller) Start(stopCh <-chan struct{}) error { } if c.client.AdsController != nil { - dnsResolver, err := ads.NewAdsDnsResolver(c.client.AdsController.Processor.Cache) + dnsResolver, err := ads.NewDnsResolver(c.client.AdsController.Processor.Cache) if err != nil { return fmt.Errorf("dns resolver of Kernel-Native mode create failed: %v", err) } - dnsResolver.StartAdsDnsResolver(stopCh) + dnsResolver.StartKernelNativeDnsController(stopCh) c.client.AdsController.Processor.DnsResolverChan = dnsResolver.Clusters } diff 
--git a/pkg/dns/dns.go b/pkg/dns/dns.go index 9eb6bfa3b..26d618cd1 100644 --- a/pkg/dns/dns.go +++ b/pkg/dns/dns.go @@ -23,8 +23,8 @@ import ( "sync" "time" - clusterv3 "github.com/envoyproxy/go-control-plane/envoy/config/cluster/v3" "github.com/miekg/dns" + "k8s.io/client-go/util/workqueue" "kmesh.net/kmesh/pkg/logger" ) @@ -41,32 +41,32 @@ const ( type DNSResolver struct { client *dns.Client + AdsDnsChan chan string ResolvConfServers []string - Cache map[string]*DomainCacheEntry + cache map[string]*DomainCacheEntry + RefreshQueue workqueue.TypedDelayingInterface[any] sync.RWMutex } type DomainCacheEntry struct { - addresses []string + Addresses []string } -// pending resolve domain info of Kennel-Native Mode, -// domain name is used for dns resolution -// cluster is used for create the apicluster -type PendingResolveDomain struct { - DomainName string - Clusters []*clusterv3.Cluster +type DomainInfo struct { + Domain string RefreshRate time.Duration } func NewDNSResolver() (*DNSResolver, error) { r := &DNSResolver{ - Cache: map[string]*DomainCacheEntry{}, + AdsDnsChan: make(chan string, 100), + cache: map[string]*DomainCacheEntry{}, client: &dns.Client{ DialTimeout: 5 * time.Second, ReadTimeout: 5 * time.Second, WriteTimeout: 5 * time.Second, }, + RefreshQueue: workqueue.NewTypedDelayingQueueWithConfig(workqueue.TypedDelayingQueueConfig[any]{Name: "RefreshQueue"}), } dnsConfig, err := dns.ClientConfigFromFile("/etc/resolv.conf") @@ -82,22 +82,52 @@ func NewDNSResolver() (*DNSResolver, error) { return r, nil } -// removeUnwatchedDomain cancels any scheduled re-resolve for names we no longer care about -func (r *DNSResolver) RemoveUnwatchedDomain(domains map[string]*PendingResolveDomain) { - r.Lock() - defer r.Unlock() - for domain := range r.Cache { - if _, ok := domains[domain]; ok { - continue +func (r *DNSResolver) StartDnsResolver(stop <-chan struct{}) { + for { + select { + case <-stop: + r.RefreshQueue.ShutDown() + return + default: + r.refreshDns() } - 
delete(r.Cache, domain) } } +func (r *DNSResolver) refreshDns() { + element, quit := r.RefreshQueue.Get() + if quit { + return + } + defer r.RefreshQueue.Done(element) + e := element.(*DomainInfo) + + r.Lock() + _, exist := r.cache[e.Domain] + r.Unlock() + // if the domain is no longer watched, no need to refresh it + if !exist { + return + } + _, ttl, err := r.resolve(e.Domain) + if err != nil { + log.Errorf("failed to dns resolve: %v", err) + return + } + if ttl > e.RefreshRate { + ttl = e.RefreshRate + } + if ttl == 0 { + ttl = DeRefreshInterval + } + r.RefreshQueue.AddAfter(e, ttl) + r.AdsDnsChan <- e.Domain +} + // This functions were copied and adapted from github.com/istio/istio/pilot/pkg/model/network.go. -func (r *DNSResolver) Resolve(domainName string) ([]string, time.Duration, error) { +func (r *DNSResolver) resolve(domainName string) ([]string, time.Duration, error) { r.RLock() - entry := r.Cache[domainName] + entry := r.cache[domainName] // This can happen when the domain is deleted before the refresher tick reaches if entry == nil { r.RUnlock() @@ -111,7 +141,7 @@ func (r *DNSResolver) Resolve(domainName string) ([]string, time.Duration, error } r.RLock() - entry.addresses = addrs + entry.Addresses = addrs r.RUnlock() return addrs, ttl, nil @@ -120,8 +150,8 @@ func (r *DNSResolver) Resolve(domainName string) ([]string, time.Duration, error // resolveDomains takes a slice of cluster func (r *DNSResolver) ResolveDomains(domainName string) { r.Lock() - if r.Cache[domainName] == nil { - r.Cache[domainName] = &DomainCacheEntry{} + if r.cache[domainName] == nil { + r.cache[domainName] = &DomainCacheEntry{} } r.Unlock() } @@ -229,9 +259,41 @@ func getMinTTL(m *dns.Msg, refreshRate time.Duration) time.Duration { func (r *DNSResolver) GetAllCachedDomains() []string { r.RLock() defer r.RUnlock() - out := make([]string, 0, len(r.Cache)) - for domain := range r.Cache { + out := make([]string, 0, len(r.cache)) + for domain := range r.cache { out = append(out, 
domain) } return out } + +func (r *DNSResolver) GetOneDomainFromCache(domain string) ([]string, bool) { + r.Lock() + addresses, ok := r.cache[domain] + r.Unlock() + return addresses.Addresses, ok +} + +func (r *DNSResolver) GetAddressesFromCache(domains map[string]struct{}) map[string]*DomainCacheEntry { + r.Lock() + defer r.Unlock() + + alreadyResolveDomains := make(map[string]*DomainCacheEntry) + for domain := range domains { + if v, ok := r.cache[domain]; ok { + alreadyResolveDomains[domain] = v + } + } + return alreadyResolveDomains +} + +func (r *DNSResolver) RemoveUnwatchDomain(domains map[string]struct{}) { + r.Lock() + defer r.Unlock() + + for domain := range r.cache { + if _, ok := domains[domain]; ok { + continue + } + delete(r.cache, domain) + } +} diff --git a/pkg/dns/dns_test.go b/pkg/dns/dns_test.go new file mode 100644 index 000000000..43a266553 --- /dev/null +++ b/pkg/dns/dns_test.go @@ -0,0 +1,144 @@ +/* + * Copyright The Kmesh Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package dns + +import ( + "math" + "reflect" + "sync" + "testing" + "time" +) + +func TestDNS(t *testing.T) { + fakeDNSServer := NewFakeDNSServer() + + testDNSResolver, err := NewDNSResolver() + if err != nil { + t.Fatal(err) + } + stopCh := make(chan struct{}) + defer close(stopCh) + // testDNSResolver.StartAdsDnsResolver(stopCh) + dnsServer := fakeDNSServer.Server.PacketConn.LocalAddr().String() + testDNSResolver.ResolvConfServers = []string{dnsServer} + go testDNSResolver.StartDnsResolver(stopCh) + + testCases := []struct { + name string + domain string + refreshRate time.Duration + ttl time.Duration + expected []string + expectedAfterTTL []string + registerDomain func(domain string) + }{ + { + name: "success", + domain: "www.google.com.", + refreshRate: 10 * time.Second, + expected: []string{"10.0.0.1", "fd00::1"}, + registerDomain: func(domain string) { + fakeDNSServer.SetHosts(domain, 1) + }, + }, + { + name: "check dns refresh after ttl, ttl < refreshRate", + domain: "www.bing.com.", + refreshRate: 10 * time.Second, + ttl: 3 * time.Second, + expected: []string{"10.0.0.2", "fd00::2"}, + expectedAfterTTL: []string{"10.0.0.3", "fd00::3"}, + registerDomain: func(domain string) { + fakeDNSServer.SetHosts(domain, 2) + fakeDNSServer.SetTTL(uint32(3)) + time.AfterFunc(time.Second, func() { + fakeDNSServer.SetHosts(domain, 3) + }) + }, + }, + { + name: "check dns refresh after ttl without update bpfmap", + domain: "www.test.com.", + refreshRate: 10 * time.Second, + ttl: 3 * time.Second, + expected: []string{"10.0.0.2", "fd00::2"}, + expectedAfterTTL: []string{"10.0.0.2", "fd00::2"}, + registerDomain: func(domain string) { + fakeDNSServer.SetHosts(domain, 2) + fakeDNSServer.SetTTL(uint32(3)) + }, + }, + { + name: "check dns refresh after refreshRate, ttl > refreshRate", + domain: "www.baidu.com.", + refreshRate: 3 * time.Second, + ttl: 10 * time.Second, + expected: []string{"10.0.0.2", "fd00::2"}, + expectedAfterTTL: []string{"10.0.0.3", "fd00::3"}, + 
registerDomain: func(domain string) { + fakeDNSServer.SetHosts(domain, 2) + fakeDNSServer.SetTTL(uint32(10)) + time.AfterFunc(time.Second, func() { + fakeDNSServer.SetHosts(domain, 3) + }) + }, + }, + { + name: "failed to resolve", + domain: "www.kmesh.test.", + refreshRate: 10 * time.Second, + expected: []string{}, + }, + } + var wg sync.WaitGroup + for _, testcase := range testCases { + wg.Add(1) + if testcase.registerDomain != nil { + testcase.registerDomain(testcase.domain) + } + + input := &DomainInfo{ + Domain: testcase.domain, + RefreshRate: testcase.refreshRate, + } + testDNSResolver.Lock() + testDNSResolver.cache[testcase.domain] = &DomainCacheEntry{} + testDNSResolver.Unlock() + testDNSResolver.RefreshQueue.AddAfter(input, 0) + + time.Sleep(2 * time.Second) + + res := testDNSResolver.GetDNSAddresses(testcase.domain) + if len(res) != 0 || len(testcase.expected) != 0 { + if !reflect.DeepEqual(res, testcase.expected) { + t.Errorf("dns resolve for %s do not match. \n got %v\nwant %v", testcase.domain, res, testcase.expected) + } + + if testcase.expectedAfterTTL != nil { + ttl := time.Duration(math.Min(float64(testcase.ttl), float64(testcase.refreshRate))) + time.Sleep(ttl + 1) + res = testDNSResolver.GetDNSAddresses(testcase.domain) + if !reflect.DeepEqual(res, testcase.expectedAfterTTL) { + t.Errorf("dns refresh after ttl failed, for %s do not match. 
\n got %v\nwant %v", testcase.domain, res, testcase.expectedAfterTTL) + } + } + } + wg.Done() + } + wg.Wait() +} diff --git a/pkg/dns/utils.go b/pkg/dns/utils.go index 5c2ca1ff7..1aceac105 100644 --- a/pkg/dns/utils.go +++ b/pkg/dns/utils.go @@ -107,8 +107,8 @@ func (s *fakeDNSServer) SetTTL(ttl uint32) { func (r *DNSResolver) GetDNSAddresses(domain string) []string { r.Lock() defer r.Unlock() - if entry, ok := r.Cache[domain]; ok { - return entry.addresses + if entry, ok := r.cache[domain]; ok { + return entry.Addresses } return nil } From b9ce6317c0b2ed3f4c44113315b8f641452e8e8f Mon Sep 17 00:00:00 2001 From: LiZhenCheng9527 Date: Thu, 6 Mar 2025 15:00:45 +0800 Subject: [PATCH 06/11] Minimize the number of public elements in the DNSResolver Signed-off-by: LiZhenCheng9527 --- pkg/controller/ads/dns.go | 7 +++---- pkg/dns/dns.go | 33 ++++++++++++++++++++------------- pkg/dns/dns_test.go | 4 ++-- 3 files changed, 25 insertions(+), 19 deletions(-) diff --git a/pkg/controller/ads/dns.go b/pkg/controller/ads/dns.go index b5e049591..245765bd8 100644 --- a/pkg/controller/ads/dns.go +++ b/pkg/controller/ads/dns.go @@ -53,8 +53,7 @@ func NewDnsResolver(adsCache *AdsCache) (*dnsController, error) { return nil, err } return &dnsController{ - Clusters: make(chan []*clusterv3.Cluster), - // dnsRefreshQueue: workqueue.NewTypedDelayingQueueWithConfig(workqueue.TypedDelayingQueueConfig[any]{Name: "dnsRefreshQueue"}), + Clusters: make(chan []*clusterv3.Cluster), cache: adsCache, dnsResolver: resolver, }, nil @@ -108,14 +107,14 @@ func (r *dnsController) resolveDomains(cds []*clusterv3.Cluster) { Domain: v.DomainName, RefreshRate: v.RefreshRate, } - r.dnsResolver.RefreshQueue.AddAfter(domainInfo, 0) + r.dnsResolver.AddDomainIntoRefreshQueue(domainInfo, 0) } go r.refreshAdsWorker(domains) } func (r *dnsController) refreshAdsWorker(domains map[string]*pendingResolveDomain) { for !(len(domains) == 0) { - domain := <-r.dnsResolver.AdsDnsChan + domain := <-r.dnsResolver.DnsChan v, ok := 
domains[domain] // will this happen? if !ok { diff --git a/pkg/dns/dns.go b/pkg/dns/dns.go index 26d618cd1..0d08255fe 100644 --- a/pkg/dns/dns.go +++ b/pkg/dns/dns.go @@ -41,10 +41,10 @@ const ( type DNSResolver struct { client *dns.Client - AdsDnsChan chan string - ResolvConfServers []string + DnsChan chan string + resolvConfServers []string cache map[string]*DomainCacheEntry - RefreshQueue workqueue.TypedDelayingInterface[any] + refreshQueue workqueue.TypedDelayingInterface[any] sync.RWMutex } @@ -59,14 +59,14 @@ type DomainInfo struct { func NewDNSResolver() (*DNSResolver, error) { r := &DNSResolver{ - AdsDnsChan: make(chan string, 100), - cache: map[string]*DomainCacheEntry{}, + DnsChan: make(chan string, 100), + cache: map[string]*DomainCacheEntry{}, client: &dns.Client{ DialTimeout: 5 * time.Second, ReadTimeout: 5 * time.Second, WriteTimeout: 5 * time.Second, }, - RefreshQueue: workqueue.NewTypedDelayingQueueWithConfig(workqueue.TypedDelayingQueueConfig[any]{Name: "RefreshQueue"}), + refreshQueue: workqueue.NewTypedDelayingQueueWithConfig(workqueue.TypedDelayingQueueConfig[any]{Name: "refreshQueue"}), } dnsConfig, err := dns.ClientConfigFromFile("/etc/resolv.conf") @@ -75,7 +75,7 @@ func NewDNSResolver() (*DNSResolver, error) { } if dnsConfig != nil { for _, s := range dnsConfig.Servers { - r.ResolvConfServers = append(r.ResolvConfServers, net.JoinHostPort(s, dnsConfig.Port)) + r.resolvConfServers = append(r.resolvConfServers, net.JoinHostPort(s, dnsConfig.Port)) } } @@ -86,7 +86,7 @@ func (r *DNSResolver) StartDnsResolver(stop <-chan struct{}) { for { select { case <-stop: - r.RefreshQueue.ShutDown() + r.refreshQueue.ShutDown() return default: r.refreshDns() @@ -95,11 +95,11 @@ func (r *DNSResolver) StartDnsResolver(stop <-chan struct{}) { } func (r *DNSResolver) refreshDns() { - element, quit := r.RefreshQueue.Get() + element, quit := r.refreshQueue.Get() if quit { return } - defer r.RefreshQueue.Done(element) + defer r.refreshQueue.Done(element) e := 
element.(*DomainInfo) r.Lock() @@ -120,8 +120,8 @@ func (r *DNSResolver) refreshDns() { if ttl == 0 { ttl = DeRefreshInterval } - r.RefreshQueue.AddAfter(e, ttl) - r.AdsDnsChan <- e.Domain + r.refreshQueue.AddAfter(e, ttl) + r.DnsChan <- e.Domain } // This functions were copied and adapted from github.com/istio/istio/pilot/pkg/model/network.go. @@ -205,7 +205,7 @@ func (r *DNSResolver) doResolve(domain string) ([]string, time.Duration, error) // Query is copied and adapted from github.com/istio/istio/pilot/pkg/model/network.go. func (r *DNSResolver) Query(req *dns.Msg) *dns.Msg { var response *dns.Msg - for _, upstream := range r.ResolvConfServers { + for _, upstream := range r.resolvConfServers { resp, _, err := r.client.Exchange(req, upstream) if err != nil || resp == nil { continue @@ -297,3 +297,10 @@ func (r *DNSResolver) RemoveUnwatchDomain(domains map[string]struct{}) { delete(r.cache, domain) } } + +func (r *DNSResolver) AddDomainIntoRefreshQueue(info *DomainInfo, time time.Duration) { + if info == nil { + return + } + r.refreshQueue.AddAfter(info, time) +} diff --git a/pkg/dns/dns_test.go b/pkg/dns/dns_test.go index 43a266553..4cb9523d9 100644 --- a/pkg/dns/dns_test.go +++ b/pkg/dns/dns_test.go @@ -35,7 +35,7 @@ func TestDNS(t *testing.T) { defer close(stopCh) // testDNSResolver.StartAdsDnsResolver(stopCh) dnsServer := fakeDNSServer.Server.PacketConn.LocalAddr().String() - testDNSResolver.ResolvConfServers = []string{dnsServer} + testDNSResolver.resolvConfServers = []string{dnsServer} go testDNSResolver.StartDnsResolver(stopCh) testCases := []struct { @@ -119,7 +119,7 @@ func TestDNS(t *testing.T) { testDNSResolver.Lock() testDNSResolver.cache[testcase.domain] = &DomainCacheEntry{} testDNSResolver.Unlock() - testDNSResolver.RefreshQueue.AddAfter(input, 0) + testDNSResolver.refreshQueue.AddAfter(input, 0) time.Sleep(2 * time.Second) From f4397eaaba03aa1e0608e715d113c8d4b4836bcd Mon Sep 17 00:00:00 2001 From: LiZhenCheng9527 Date: Wed, 19 Mar 2025 19:01:22 
+0800 Subject: [PATCH 07/11] Optimized cds refresh processing Signed-off-by: LiZhenCheng9527 --- pkg/controller/ads/ads_controller.go | 25 ++- pkg/controller/ads/ads_controller_test.go | 4 +- pkg/controller/ads/dns.go | 248 ++++++++++++++-------- pkg/controller/ads/dns_test.go | 59 +++-- pkg/controller/controller.go | 12 +- pkg/dns/dns.go | 17 +- 6 files changed, 232 insertions(+), 133 deletions(-) diff --git a/pkg/controller/ads/ads_controller.go b/pkg/controller/ads/ads_controller.go index cc993f319..cd84d8e9a 100644 --- a/pkg/controller/ads/ads_controller.go +++ b/pkg/controller/ads/ads_controller.go @@ -33,8 +33,9 @@ var ( ) type Controller struct { - Processor *processor - con *connection + Processor *processor + dnsResolverController *dnsController + con *connection } type connection struct { @@ -44,8 +45,17 @@ type connection struct { } func NewController(bpfAds *bpfads.BpfAds) *Controller { + processor := newProcessor(bpfAds) + // create kernel-native mode ads resolver controller + dnsResolverController, err := NewDnsResolver(processor.Cache) + if err != nil { + log.Errorf("dns resolver of Kernel-Native mode create failed: %v", err) + } + processor.DnsResolverChan = dnsResolverController.Clusters + return &Controller{ - Processor: newProcessor(bpfAds), + dnsResolverController: dnsResolverController, + Processor: processor, } } @@ -84,6 +94,9 @@ func (c *Controller) HandleAdsStream() error { return fmt.Errorf("stream recv failed, %s", err) } + // Because Kernel-Native mode is full update. + // So the original clusterCache is deleted when a new resp is received. 
+ c.dnsResolverController.newClusterCache() c.Processor.processAdsResponse(rsp) c.con.requestsChan.Put(c.Processor.ack) if c.Processor.req != nil { @@ -115,3 +128,9 @@ func (c *Controller) Close() { _ = c.con.Stream.CloseSend() } } + +func (c *Controller) StartDnsController(stopCh <-chan struct{}) { + if c.dnsResolverController != nil { + c.dnsResolverController.Run(stopCh) + } +} diff --git a/pkg/controller/ads/ads_controller_test.go b/pkg/controller/ads/ads_controller_test.go index 70c4a36d6..f6aa72775 100644 --- a/pkg/controller/ads/ads_controller_test.go +++ b/pkg/controller/ads/ads_controller_test.go @@ -119,9 +119,8 @@ func TestHandleAdsStream(t *testing.T) { adsStream := NewController(nil) adsStream.con = &connection{Stream: fakeClient.AdsClient, requestsChan: channels.NewUnbounded[*service_discovery_v3.DiscoveryRequest](), stopCh: make(chan struct{})} - + adsStream.dnsResolverController.Run(make(chan struct{})) patches1 := gomonkey.NewPatches() - patches2 := gomonkey.NewPatches() tests := []struct { name string beforeFunc func() @@ -161,7 +160,6 @@ func TestHandleAdsStream(t *testing.T) { }, afterFunc: func() { patches1.Reset() - patches2.Reset() }, wantErr: false, }, diff --git a/pkg/controller/ads/dns.go b/pkg/controller/ads/dns.go index 245765bd8..0cad2b1bc 100644 --- a/pkg/controller/ads/dns.go +++ b/pkg/controller/ads/dns.go @@ -17,14 +17,17 @@ package ads import ( + "fmt" "net" "net/netip" "slices" + "sync" "time" clusterv3 "github.com/envoyproxy/go-control-plane/envoy/config/cluster/v3" v3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" endpointv3 "github.com/envoyproxy/go-control-plane/envoy/config/endpoint/v3" + "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/wrapperspb" core_v2 "kmesh.net/kmesh/api/v2/core" @@ -36,6 +39,11 @@ type dnsController struct { Clusters chan []*clusterv3.Cluster cache *AdsCache dnsResolver *dns.DNSResolver + // Store the copy of pendingResolveDomain. 
+ clusterCache map[string]*pendingResolveDomain + // store all pending hostnames in the clusters + pendingClusterInfo map[string][]string + sync.RWMutex } // pending resolve domain info of Kennel-Native Mode, @@ -53,16 +61,19 @@ func NewDnsResolver(adsCache *AdsCache) (*dnsController, error) { return nil, err } return &dnsController{ - Clusters: make(chan []*clusterv3.Cluster), - cache: adsCache, - dnsResolver: resolver, + Clusters: make(chan []*clusterv3.Cluster), + cache: adsCache, + dnsResolver: resolver, + clusterCache: make(map[string]*pendingResolveDomain), + pendingClusterInfo: make(map[string][]string), }, nil } -func (r *dnsController) StartKernelNativeDnsController(stopCh <-chan struct{}) { - go r.startDnsController() +func (r *dnsController) Run(stopCh <-chan struct{}) { // start dns resolver go r.dnsResolver.StartDnsResolver(stopCh) + go r.refreshAdsWorker(stopCh) + go r.startDnsController() go func() { <-stopCh close(r.Clusters) @@ -70,68 +81,63 @@ func (r *dnsController) StartKernelNativeDnsController(stopCh <-chan struct{}) { } func (r *dnsController) startDnsController() { - rateLimiter := make(chan struct{}, dns.MaxConcurrency) for clusters := range r.Clusters { - rateLimiter <- struct{}{} - go func(clusters []*clusterv3.Cluster) { - defer func() { - <-rateLimiter - }() - r.resolveDomains(clusters) - }(clusters) + r.resolveDomains(clusters) } } func (r *dnsController) resolveDomains(cds []*clusterv3.Cluster) { - domains := getPendingResolveDomain(cds) - hostNames := make(map[string]struct{}) + domains, hostNames := getPendingResolveDomain(cds) - for k := range domains { - hostNames[k] = struct{}{} + // store all pending hostnames of clusters in r.hostInfo + for _, cluster := range cds { + clusterName := cluster.GetName() + info := getHostInfo(cluster) + r.pendingClusterInfo[clusterName] = info } // delete any scheduled re-resolve for domains we no longer care about r.dnsResolver.RemoveUnwatchDomain(hostNames) - // Directly update the clusters that 
can find the dns resolution result in the cache - alreadyResolveDomains := r.dnsResolver.GetAddressesFromCache(hostNames) - for k, v := range alreadyResolveDomains { - pendingDomain := domains[k] - r.adsDnsResolve(pendingDomain, v.Addresses) - r.cache.ClusterCache.Flush() - delete(domains, k) - } for k, v := range domains { - r.dnsResolver.ResolveDomains(k) - domainInfo := &dns.DomainInfo{ - Domain: v.DomainName, - RefreshRate: v.RefreshRate, + addresses := r.dnsResolver.GetDNSAddresses(k) + // Already have record in dns cache + if addresses != nil { + r.updateClusters(v, addresses) + go r.cache.ClusterCache.Flush() + } else { + r.dnsResolver.InitializeDomainInCache(k) + domainInfo := &dns.DomainInfo{ + Domain: v.DomainName, + RefreshRate: v.RefreshRate, + } + r.dnsResolver.ScheduleDomainRefresh(domainInfo, 0) } - r.dnsResolver.AddDomainIntoRefreshQueue(domainInfo, 0) } - go r.refreshAdsWorker(domains) } -func (r *dnsController) refreshAdsWorker(domains map[string]*pendingResolveDomain) { - for !(len(domains) == 0) { - domain := <-r.dnsResolver.DnsChan - v, ok := domains[domain] - // will this happen? 
- if !ok { - continue +func (r *dnsController) refreshAdsWorker(stop <-chan struct{}) { + for { + select { + case <-stop: + return + default: + domain := <-r.dnsResolver.DnsChan + pendingDomain := r.getClustersByDomain(domain) + addrs := r.dnsResolver.GetDNSAddresses(domain) + r.updateClusters(pendingDomain, addrs) } - addresses, _ := r.dnsResolver.GetOneDomainFromCache(domain) - r.adsDnsResolve(v, addresses) - r.cache.ClusterCache.Flush() - delete(domains, domain) } } -func (r *dnsController) adsDnsResolve(pendingDomain *pendingResolveDomain, addrs []string) { +func (r *dnsController) updateClusters(pendingDomain *pendingResolveDomain, addrs []string) { + if pendingDomain == nil || addrs == nil { + return + } for _, cluster := range pendingDomain.Clusters { - ready := overwriteDnsCluster(cluster, pendingDomain.DomainName, addrs) + ready, newCluster := r.overwriteDnsCluster(cluster, pendingDomain.DomainName, addrs) if ready { - if !r.cache.UpdateApiClusterIfExists(core_v2.ApiStatus_UPDATE, cluster) { + if !r.cache.UpdateApiClusterIfExists(core_v2.ApiStatus_UPDATE, newCluster) { log.Debugf("cluster: %s is deleted", cluster.Name) return } @@ -139,73 +145,105 @@ func (r *dnsController) adsDnsResolve(pendingDomain *pendingResolveDomain, addrs } } -func overwriteDnsCluster(cluster *clusterv3.Cluster, domain string, addrs []string) bool { - buildLbEndpoints := func(port uint32) []*endpointv3.LbEndpoint { - lbEndpoints := make([]*endpointv3.LbEndpoint, 0, len(addrs)) - for _, addr := range addrs { - ip := net.ParseIP(addr) - if ip == nil { - continue - } - if ip.To4() == nil { - continue +func (r *dnsController) overwriteDnsCluster(cluster *clusterv3.Cluster, domain string, addrs []string) (bool, *clusterv3.Cluster) { + ready := true + hostNames := r.pendingClusterInfo[cluster.GetName()] + addressesOfHostname := make(map[string][]string) + + for _, hostName := range hostNames { + addresses := r.dnsResolver.GetDNSAddresses(hostName) + // There are hostnames in this Cluster 
that are not resolved. + if addresses != nil { + addressesOfHostname[hostName] = addresses + } else { + ready = false + } + } + + if ready { + newCluster := cloneCluster(cluster) + for _, e := range newCluster.LoadAssignment.Endpoints { + pos := -1 + var lbEndpoints []*endpointv3.LbEndpoint + for i, le := range e.LbEndpoints { + socketAddr, ok := le.GetEndpoint().GetAddress().GetAddress().(*v3.Address_SocketAddress) + if !ok { + continue + } + _, err := netip.ParseAddr(socketAddr.SocketAddress.Address) + if err != nil { + host := socketAddr.SocketAddress.Address + addresses := addressesOfHostname[host] + fmt.Printf("addresses %#v", addresses) + pos = i + lbEndpoints = buildLbEndpoints(socketAddr.SocketAddress.GetPortValue(), addresses) + } } - lbEndpoint := &endpointv3.LbEndpoint{ - HealthStatus: v3.HealthStatus_HEALTHY, - HostIdentifier: &endpointv3.LbEndpoint_Endpoint{ - Endpoint: &endpointv3.Endpoint{ - Address: &v3.Address{ - Address: &v3.Address_SocketAddress{ - SocketAddress: &v3.SocketAddress{ - Address: addr, - PortSpecifier: &v3.SocketAddress_PortValue{ - PortValue: port, - }, + e.LbEndpoints = slices.Replace(e.LbEndpoints, pos, pos+1, lbEndpoints...) 
+ } + return ready, newCluster + } + + return ready, nil +} + +func buildLbEndpoints(port uint32, addrs []string) []*endpointv3.LbEndpoint { + lbEndpoints := make([]*endpointv3.LbEndpoint, 0, len(addrs)) + for _, addr := range addrs { + ip := net.ParseIP(addr) + if ip == nil { + continue + } + if ip.To4() == nil { + continue + } + lbEndpoint := &endpointv3.LbEndpoint{ + HealthStatus: v3.HealthStatus_HEALTHY, + HostIdentifier: &endpointv3.LbEndpoint_Endpoint{ + Endpoint: &endpointv3.Endpoint{ + Address: &v3.Address{ + Address: &v3.Address_SocketAddress{ + SocketAddress: &v3.SocketAddress{ + Address: addr, + PortSpecifier: &v3.SocketAddress_PortValue{ + PortValue: port, }, }, }, }, }, - // TODO: support LoadBalancingWeight - LoadBalancingWeight: &wrapperspb.UInt32Value{ - Value: 1, - }, - } - lbEndpoints = append(lbEndpoints, lbEndpoint) + }, + // TODO: support LoadBalancingWeight + LoadBalancingWeight: &wrapperspb.UInt32Value{ + Value: 1, + }, } - return lbEndpoints + lbEndpoints = append(lbEndpoints, lbEndpoint) } + return lbEndpoints +} - ready := true +func getHostInfo(cluster *clusterv3.Cluster) []string { + info := []string{} for _, e := range cluster.LoadAssignment.Endpoints { - pos := -1 - var lbEndpoints []*endpointv3.LbEndpoint - for i, le := range e.LbEndpoints { + for _, le := range e.LbEndpoints { socketAddr, ok := le.GetEndpoint().GetAddress().GetAddress().(*v3.Address_SocketAddress) if !ok { continue } _, err := netip.ParseAddr(socketAddr.SocketAddress.Address) if err != nil { - if socketAddr.SocketAddress.Address == domain { - pos = i - lbEndpoints = buildLbEndpoints(socketAddr.SocketAddress.GetPortValue()) - } else { - // There is other domains not resolved for this cluster - ready = false - } + info = append(info, socketAddr.SocketAddress.Address) } } - if pos >= 0 { - e.LbEndpoints = slices.Replace(e.LbEndpoints, pos, pos+1, lbEndpoints...) 
- } } - return ready + return info } -func getPendingResolveDomain(cds []*clusterv3.Cluster) map[string]*pendingResolveDomain { +func getPendingResolveDomain(cds []*clusterv3.Cluster) (map[string]*pendingResolveDomain, map[string]struct{}) { domains := make(map[string]*pendingResolveDomain) + hostNames := make(map[string]struct{}) for _, cluster := range cds { if cluster.LoadAssignment == nil { @@ -233,10 +271,40 @@ func getPendingResolveDomain(cds []*clusterv3.Cluster) map[string]*pendingResolv RefreshRate: cluster.GetDnsRefreshRate().AsDuration(), } domains[address] = domainWithRefreshRate + hostNames[address] = struct{}{} } } } } - return domains + return domains, hostNames +} + +func (r *dnsController) newClusterCache() { + if r.clusterCache != nil { + r.Lock() + defer r.Unlock() + log.Debug("clean up dns clusters") + r.clusterCache = map[string]*pendingResolveDomain{} + return + } +} + +func (r *dnsController) getClustersByDomain(domain string) *pendingResolveDomain { + if r.clusterCache != nil { + r.RLock() + defer r.RUnlock() + if v, ok := r.clusterCache[domain]; ok { + return v + } + } + return nil +} + +func cloneCluster(cluster *clusterv3.Cluster) *clusterv3.Cluster { + if cluster == nil { + return nil + } + clusterCopy := proto.Clone(cluster).(*clusterv3.Cluster) + return clusterCopy } diff --git a/pkg/controller/ads/dns_test.go b/pkg/controller/ads/dns_test.go index 30a080cfe..ba6d43ac5 100644 --- a/pkg/controller/ads/dns_test.go +++ b/pkg/controller/ads/dns_test.go @@ -23,6 +23,7 @@ import ( "testing" "time" + "github.com/agiledragon/gomonkey/v2" clusterv3 "github.com/envoyproxy/go-control-plane/envoy/config/cluster/v3" v3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" endpointv3 "github.com/envoyproxy/go-control-plane/envoy/config/endpoint/v3" @@ -32,11 +33,12 @@ import ( "istio.io/istio/pkg/test/util/retry" core_v2 "kmesh.net/kmesh/api/v2/core" + "kmesh.net/kmesh/pkg/dns" ) func TestOverwriteDNSCluster(t *testing.T) { domain := 
"www.google.com" - addrs := []string{"10.1.1.1", "10.1.1.2"} + addrs := []string{"10.1.1.1", "10.1.1.2", "10.1.1.3"} cluster := &clusterv3.Cluster{ Name: "ut-cluster", ClusterDiscoveryType: &clusterv3.Cluster_Type{ @@ -72,23 +74,44 @@ func TestOverwriteDNSCluster(t *testing.T) { }, } - overwriteDnsCluster(cluster, domain, addrs) - - endpoints := cluster.GetLoadAssignment().GetEndpoints()[0].GetLbEndpoints() - if len(endpoints) != 2 { - t.Errorf("Expected 2 LbEndpoints, but got %d", len(endpoints)) + p := NewController(nil).Processor + stopCh := make(chan struct{}) + defer close(stopCh) + dnsResolver, err := NewDnsResolver(p.Cache) + assert.NoError(t, err) + p.DnsResolverChan = dnsResolver.Clusters + dnsResolver.pendingClusterInfo = map[string][]string{ + cluster.GetName(): []string{ + domain, + }, } - out := []string{} - for _, e := range endpoints { - socketAddr, ok := e.GetEndpoint().GetAddress().GetAddress().(*v3.Address_SocketAddress) - if !ok { - continue + patches := gomonkey.NewPatches() + defer patches.Reset() + patches.ApplyMethod(reflect.TypeOf(dnsResolver.dnsResolver), "GetDNSAddresses", + func(_ *dns.DNSResolver, name string) []string { + return addrs + }) + + ready, newCluster := dnsResolver.overwriteDnsCluster(cluster, domain, addrs) + assert.Equal(t, true, ready) + + if ready { + endpoints := newCluster.GetLoadAssignment().GetEndpoints()[0].GetLbEndpoints() + if len(endpoints) != 3 { + t.Errorf("Expected 3 LbEndpoints, but got %d", len(endpoints)) + } + out := []string{} + for _, e := range endpoints { + socketAddr, ok := e.GetEndpoint().GetAddress().GetAddress().(*v3.Address_SocketAddress) + if !ok { + continue + } + address := socketAddr.SocketAddress.Address + out = append(out, address) + } + if !slices.Equal(out, addrs) { + t.Errorf("OverwriteDNSCluster error, expected %v, but got %v", out, addrs) } - address := socketAddr.SocketAddress.Address - out = append(out, address) - } - if !slices.Equal(out, addrs) { - t.Errorf("OverwriteDNSCluster 
error, expected %v, but got %v", out, addrs) } } @@ -203,7 +226,7 @@ func TestHandleCdsResponseWithDns(t *testing.T) { defer close(stopCh) dnsResolver, err := NewDnsResolver(p.Cache) assert.NoError(t, err) - dnsResolver.StartKernelNativeDnsController(stopCh) + dnsResolver.Run(stopCh) p.DnsResolverChan = dnsResolver.Clusters for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { @@ -307,7 +330,7 @@ func TestGetPendingResolveDomain(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if got := getPendingResolveDomain(tt.args.clusters); !reflect.DeepEqual(got, tt.want) { + if got, _ := getPendingResolveDomain(tt.args.clusters); !reflect.DeepEqual(got, tt.want) { t.Errorf("getPendingResolveDomain() = %v, want %v", got, tt.want) } }) diff --git a/pkg/controller/controller.go b/pkg/controller/controller.go index 4845f58e3..1c9886d25 100644 --- a/pkg/controller/controller.go +++ b/pkg/controller/controller.go @@ -27,7 +27,6 @@ import ( bpfads "kmesh.net/kmesh/pkg/bpf/ads" bpfwl "kmesh.net/kmesh/pkg/bpf/workload" "kmesh.net/kmesh/pkg/constants" - "kmesh.net/kmesh/pkg/controller/ads" "kmesh.net/kmesh/pkg/controller/bypass" "kmesh.net/kmesh/pkg/controller/encryption/ipsec" manage "kmesh.net/kmesh/pkg/controller/manage" @@ -153,15 +152,8 @@ func (c *Controller) Start(stopCh <-chan struct{}) error { if c.client.WorkloadController != nil { c.client.WorkloadController.Run(ctx) - } - - if c.client.AdsController != nil { - dnsResolver, err := ads.NewDnsResolver(c.client.AdsController.Processor.Cache) - if err != nil { - return fmt.Errorf("dns resolver of Kernel-Native mode create failed: %v", err) - } - dnsResolver.StartKernelNativeDnsController(stopCh) - c.client.AdsController.Processor.DnsResolverChan = dnsResolver.Clusters + } else { + c.client.AdsController.StartDnsController(stopCh) } return c.client.Run(stopCh) diff --git a/pkg/dns/dns.go b/pkg/dns/dns.go index 0d08255fe..51784d8cd 100644 --- a/pkg/dns/dns.go +++ b/pkg/dns/dns.go @@ 
-147,8 +147,7 @@ func (r *DNSResolver) resolve(domainName string) ([]string, time.Duration, error return addrs, ttl, nil } -// resolveDomains takes a slice of cluster -func (r *DNSResolver) ResolveDomains(domainName string) { +func (r *DNSResolver) InitializeDomainInCache(domainName string) { r.Lock() if r.cache[domainName] == nil { r.cache[domainName] = &DomainCacheEntry{} @@ -266,16 +265,16 @@ func (r *DNSResolver) GetAllCachedDomains() []string { return out } -func (r *DNSResolver) GetOneDomainFromCache(domain string) ([]string, bool) { - r.Lock() +func (r *DNSResolver) GetDomainAddress(domain string) ([]string, bool) { + r.RLock() addresses, ok := r.cache[domain] - r.Unlock() + r.RUnlock() return addresses.Addresses, ok } -func (r *DNSResolver) GetAddressesFromCache(domains map[string]struct{}) map[string]*DomainCacheEntry { - r.Lock() - defer r.Unlock() +func (r *DNSResolver) GetBatchAddressesFromCache(domains map[string]struct{}) map[string]*DomainCacheEntry { + r.RLock() + defer r.RUnlock() alreadyResolveDomains := make(map[string]*DomainCacheEntry) for domain := range domains { @@ -298,7 +297,7 @@ func (r *DNSResolver) RemoveUnwatchDomain(domains map[string]struct{}) { } } -func (r *DNSResolver) AddDomainIntoRefreshQueue(info *DomainInfo, time time.Duration) { +func (r *DNSResolver) ScheduleDomainRefresh(info *DomainInfo, time time.Duration) { if info == nil { return } From c4cb113e391824db00afa6f996dccda47817e393 Mon Sep 17 00:00:00 2001 From: LiZhenCheng9527 Date: Mon, 24 Mar 2025 19:48:36 +0800 Subject: [PATCH 08/11] add more comments Signed-off-by: LiZhenCheng9527 --- pkg/controller/ads/dns.go | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/pkg/controller/ads/dns.go b/pkg/controller/ads/dns.go index 0cad2b1bc..85e7aef75 100644 --- a/pkg/controller/ads/dns.go +++ b/pkg/controller/ads/dns.go @@ -70,9 +70,11 @@ func NewDnsResolver(adsCache *AdsCache) (*dnsController, error) { } func (r *dnsController) Run(stopCh <-chan 
struct{}) { - // start dns resolver + // Start dns resolver go r.dnsResolver.StartDnsResolver(stopCh) + // Handle cds updates when a hostname completes resolution go r.refreshAdsWorker(stopCh) + // Consumption of clusters to be resolved. go r.startDnsController() go func() { <-stopCh @@ -92,13 +94,14 @@ func (r *dnsController) resolveDomains(cds []*clusterv3.Cluster) { // store all pending hostnames of clusters in r.hostInfo for _, cluster := range cds { clusterName := cluster.GetName() - info := getHostInfo(cluster) + info := getHostName(cluster) r.pendingClusterInfo[clusterName] = info } // delete any scheduled re-resolve for domains we no longer care about r.dnsResolver.RemoveUnwatchDomain(hostNames) + // Update clusters based on the data in the dns cache. for k, v := range domains { addresses := r.dnsResolver.GetDNSAddresses(k) // Already have record in dns cache @@ -106,6 +109,8 @@ func (r *dnsController) resolveDomains(cds []*clusterv3.Cluster) { r.updateClusters(v, addresses) go r.cache.ClusterCache.Flush() } else { + // Initialize the newly added hostname + // and add it to the dns queue to be resolved. 
r.dnsResolver.InitializeDomainInCache(k) domainInfo := &dns.DomainInfo{ Domain: v.DomainName, @@ -116,6 +121,7 @@ func (r *dnsController) resolveDomains(cds []*clusterv3.Cluster) { } } +// Handle cds updates when a hostname completes resolution func (r *dnsController) refreshAdsWorker(stop <-chan struct{}) { for { select { @@ -223,7 +229,8 @@ func buildLbEndpoints(port uint32, addrs []string) []*endpointv3.LbEndpoint { return lbEndpoints } -func getHostInfo(cluster *clusterv3.Cluster) []string { +// Get the hostname to be resolved in Cluster +func getHostName(cluster *clusterv3.Cluster) []string { info := []string{} for _, e := range cluster.LoadAssignment.Endpoints { for _, le := range e.LbEndpoints { From 44d5b2b257ab23330412bd5f69931988286b683a Mon Sep 17 00:00:00 2001 From: LiZhenCheng9527 Date: Tue, 25 Mar 2025 09:22:33 +0800 Subject: [PATCH 09/11] fix go lint error Signed-off-by: LiZhenCheng9527 --- pkg/controller/ads/ads_controller.go | 2 +- pkg/controller/ads/dns.go | 68 +++++++++++++++------------- pkg/controller/ads/dns_test.go | 11 ++--- pkg/dns/dns.go | 4 +- 4 files changed, 45 insertions(+), 40 deletions(-) diff --git a/pkg/controller/ads/ads_controller.go b/pkg/controller/ads/ads_controller.go index cd84d8e9a..bf8d06891 100644 --- a/pkg/controller/ads/ads_controller.go +++ b/pkg/controller/ads/ads_controller.go @@ -47,7 +47,7 @@ type connection struct { func NewController(bpfAds *bpfads.BpfAds) *Controller { processor := newProcessor(bpfAds) // create kernel-native mode ads resolver controller - dnsResolverController, err := NewDnsResolver(processor.Cache) + dnsResolverController, err := NewDnsController(processor.Cache) if err != nil { log.Errorf("dns resolver of Kernel-Native mode create failed: %v", err) } diff --git a/pkg/controller/ads/dns.go b/pkg/controller/ads/dns.go index 85e7aef75..b524d72c6 100644 --- a/pkg/controller/ads/dns.go +++ b/pkg/controller/ads/dns.go @@ -42,7 +42,7 @@ type dnsController struct { // Store the copy of 
pendingResolveDomain. clusterCache map[string]*pendingResolveDomain // store all pending hostnames in the clusters - pendingClusterInfo map[string][]string + pendingHostnames map[string][]string sync.RWMutex } @@ -50,52 +50,51 @@ type dnsController struct { // domain name is used for dns resolution // cluster is used for create the apicluster type pendingResolveDomain struct { - DomainName string Clusters []*clusterv3.Cluster RefreshRate time.Duration } -func NewDnsResolver(adsCache *AdsCache) (*dnsController, error) { +func NewDnsController(adsCache *AdsCache) (*dnsController, error) { resolver, err := dns.NewDNSResolver() if err != nil { return nil, err } return &dnsController{ - Clusters: make(chan []*clusterv3.Cluster), - cache: adsCache, - dnsResolver: resolver, - clusterCache: make(map[string]*pendingResolveDomain), - pendingClusterInfo: make(map[string][]string), + Clusters: make(chan []*clusterv3.Cluster), + cache: adsCache, + dnsResolver: resolver, + clusterCache: make(map[string]*pendingResolveDomain), + pendingHostnames: make(map[string][]string), }, nil } func (r *dnsController) Run(stopCh <-chan struct{}) { // Start dns resolver go r.dnsResolver.StartDnsResolver(stopCh) - // Handle cds updates when a hostname completes resolution - go r.refreshAdsWorker(stopCh) - // Consumption of clusters to be resolved. - go r.startDnsController() + // Handle cds updates + go r.refreshWorker(stopCh) + // Consumption of clusters. 
+ go r.processClusterDomains() go func() { <-stopCh close(r.Clusters) }() } -func (r *dnsController) startDnsController() { +func (r *dnsController) processClusterDomains() { for clusters := range r.Clusters { - r.resolveDomains(clusters) + r.getDomains(clusters) } } -func (r *dnsController) resolveDomains(cds []*clusterv3.Cluster) { +func (r *dnsController) getDomains(cds []*clusterv3.Cluster) { domains, hostNames := getPendingResolveDomain(cds) - // store all pending hostnames of clusters in r.hostInfo + // store all pending hostnames of clusters in pendingHostnames for _, cluster := range cds { clusterName := cluster.GetName() info := getHostName(cluster) - r.pendingClusterInfo[clusterName] = info + r.pendingHostnames[clusterName] = info } // delete any scheduled re-resolve for domains we no longer care about @@ -106,23 +105,23 @@ func (r *dnsController) resolveDomains(cds []*clusterv3.Cluster) { addresses := r.dnsResolver.GetDNSAddresses(k) // Already have record in dns cache if addresses != nil { - r.updateClusters(v, addresses) - go r.cache.ClusterCache.Flush() + // k(domain) has been resolved, triggering refreshWorker. + r.dnsResolver.DnsChan <- k } else { // Initialize the newly added hostname // and add it to the dns queue to be resolved. 
- r.dnsResolver.InitializeDomainInCache(k) + r.dnsResolver.AddDomainInCache(k) domainInfo := &dns.DomainInfo{ - Domain: v.DomainName, + Domain: k, RefreshRate: v.RefreshRate, } - r.dnsResolver.ScheduleDomainRefresh(domainInfo, 0) + r.dnsResolver.AddDomainInQueue(domainInfo, 0) } } } -// Handle cds updates when a hostname completes resolution -func (r *dnsController) refreshAdsWorker(stop <-chan struct{}) { +// Handle cds updates +func (r *dnsController) refreshWorker(stop <-chan struct{}) { for { select { case <-stop: @@ -131,29 +130,37 @@ func (r *dnsController) refreshAdsWorker(stop <-chan struct{}) { domain := <-r.dnsResolver.DnsChan pendingDomain := r.getClustersByDomain(domain) addrs := r.dnsResolver.GetDNSAddresses(domain) - r.updateClusters(pendingDomain, addrs) + ready := r.updateClusters(pendingDomain, domain, addrs) + if ready { + go r.cache.ClusterCache.Flush() + } } } } -func (r *dnsController) updateClusters(pendingDomain *pendingResolveDomain, addrs []string) { +func (r *dnsController) updateClusters(pendingDomain *pendingResolveDomain, domain string, addrs []string) bool { + isClusterUpdate := false if pendingDomain == nil || addrs == nil { - return + return false } for _, cluster := range pendingDomain.Clusters { - ready, newCluster := r.overwriteDnsCluster(cluster, pendingDomain.DomainName, addrs) + ready, newCluster := r.overwriteDnsCluster(cluster, domain, addrs) if ready { if !r.cache.UpdateApiClusterIfExists(core_v2.ApiStatus_UPDATE, newCluster) { log.Debugf("cluster: %s is deleted", cluster.Name) - return + return false + } else { + isClusterUpdate = true } } } + // if one cluster update successful, we will retuen true + return isClusterUpdate } func (r *dnsController) overwriteDnsCluster(cluster *clusterv3.Cluster, domain string, addrs []string) (bool, *clusterv3.Cluster) { ready := true - hostNames := r.pendingClusterInfo[cluster.GetName()] + hostNames := r.pendingHostnames[cluster.GetName()] addressesOfHostname := make(map[string][]string) 
for _, hostName := range hostNames { @@ -273,7 +280,6 @@ func getPendingResolveDomain(cds []*clusterv3.Cluster) (map[string]*pendingResol v.Clusters = append(v.Clusters, cluster) } else { domainWithRefreshRate := &pendingResolveDomain{ - DomainName: address, Clusters: []*clusterv3.Cluster{cluster}, RefreshRate: cluster.GetDnsRefreshRate().AsDuration(), } diff --git a/pkg/controller/ads/dns_test.go b/pkg/controller/ads/dns_test.go index ba6d43ac5..15556acca 100644 --- a/pkg/controller/ads/dns_test.go +++ b/pkg/controller/ads/dns_test.go @@ -77,11 +77,11 @@ func TestOverwriteDNSCluster(t *testing.T) { p := NewController(nil).Processor stopCh := make(chan struct{}) defer close(stopCh) - dnsResolver, err := NewDnsResolver(p.Cache) + dnsResolver, err := NewDnsController(p.Cache) assert.NoError(t, err) p.DnsResolverChan = dnsResolver.Clusters - dnsResolver.pendingClusterInfo = map[string][]string{ - cluster.GetName(): []string{ + dnsResolver.pendingHostnames = map[string][]string{ + cluster.GetName(): { domain, }, } @@ -224,7 +224,7 @@ func TestHandleCdsResponseWithDns(t *testing.T) { p := NewController(nil).Processor stopCh := make(chan struct{}) defer close(stopCh) - dnsResolver, err := NewDnsResolver(p.Cache) + dnsResolver, err := NewDnsController(p.Cache) assert.NoError(t, err) dnsResolver.Run(stopCh) p.DnsResolverChan = dnsResolver.Clusters @@ -322,8 +322,7 @@ func TestGetPendingResolveDomain(t *testing.T) { }, want: map[string]*pendingResolveDomain{ "www.google.com": { - DomainName: "www.google.com", - Clusters: []*clusterv3.Cluster{&utClusterWithHost}, + Clusters: []*clusterv3.Cluster{&utClusterWithHost}, }, }, }, diff --git a/pkg/dns/dns.go b/pkg/dns/dns.go index 51784d8cd..5accea0fa 100644 --- a/pkg/dns/dns.go +++ b/pkg/dns/dns.go @@ -147,7 +147,7 @@ func (r *DNSResolver) resolve(domainName string) ([]string, time.Duration, error return addrs, ttl, nil } -func (r *DNSResolver) InitializeDomainInCache(domainName string) { +func (r *DNSResolver) 
AddDomainInCache(domainName string) { r.Lock() if r.cache[domainName] == nil { r.cache[domainName] = &DomainCacheEntry{} @@ -297,7 +297,7 @@ func (r *DNSResolver) RemoveUnwatchDomain(domains map[string]struct{}) { } } -func (r *DNSResolver) ScheduleDomainRefresh(info *DomainInfo, time time.Duration) { +func (r *DNSResolver) AddDomainInQueue(info *DomainInfo, time time.Duration) { if info == nil { return } From 282fddadf9c0ac3063262d4baca563c716446b54 Mon Sep 17 00:00:00 2001 From: LiZhenCheng9527 Date: Thu, 27 Mar 2025 17:17:50 +0800 Subject: [PATCH 10/11] change function name Signed-off-by: LiZhenCheng9527 --- pkg/controller/ads/dns.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pkg/controller/ads/dns.go b/pkg/controller/ads/dns.go index b524d72c6..91d92357c 100644 --- a/pkg/controller/ads/dns.go +++ b/pkg/controller/ads/dns.go @@ -74,20 +74,20 @@ func (r *dnsController) Run(stopCh <-chan struct{}) { // Handle cds updates go r.refreshWorker(stopCh) // Consumption of clusters. 
- go r.processClusterDomains() + go r.processClusters() go func() { <-stopCh close(r.Clusters) }() } -func (r *dnsController) processClusterDomains() { +func (r *dnsController) processClusters() { for clusters := range r.Clusters { - r.getDomains(clusters) + r.processDomains(clusters) } } -func (r *dnsController) getDomains(cds []*clusterv3.Cluster) { +func (r *dnsController) processDomains(cds []*clusterv3.Cluster) { domains, hostNames := getPendingResolveDomain(cds) // store all pending hostnames of clusters in pendingHostnames From 2f054db3708a64383b86665661014f94d1ae5fbf Mon Sep 17 00:00:00 2001 From: LiZhenCheng9527 Date: Fri, 28 Mar 2025 16:10:47 +0800 Subject: [PATCH 11/11] Modified based on comments Signed-off-by: LiZhenCheng9527 --- pkg/controller/ads/ads_controller.go | 3 +- pkg/controller/ads/dns.go | 63 +++++++++++++--------------- pkg/controller/ads/dns_test.go | 23 +++++----- pkg/dns/dns.go | 18 ++++---- 4 files changed, 53 insertions(+), 54 deletions(-) diff --git a/pkg/controller/ads/ads_controller.go b/pkg/controller/ads/ads_controller.go index bf8d06891..c555b55a9 100644 --- a/pkg/controller/ads/ads_controller.go +++ b/pkg/controller/ads/ads_controller.go @@ -50,8 +50,9 @@ func NewController(bpfAds *bpfads.BpfAds) *Controller { dnsResolverController, err := NewDnsController(processor.Cache) if err != nil { log.Errorf("dns resolver of Kernel-Native mode create failed: %v", err) + return nil } - processor.DnsResolverChan = dnsResolverController.Clusters + processor.DnsResolverChan = dnsResolverController.clustersChan return &Controller{ dnsResolverController: dnsResolverController, diff --git a/pkg/controller/ads/dns.go b/pkg/controller/ads/dns.go index 91d92357c..4466a7e90 100644 --- a/pkg/controller/ads/dns.go +++ b/pkg/controller/ads/dns.go @@ -36,9 +36,9 @@ import ( // adsDnsResolver is DNS resolver of Kernel Native type dnsController struct { - Clusters chan []*clusterv3.Cluster - cache *AdsCache - dnsResolver *dns.DNSResolver + clustersChan 
chan []*clusterv3.Cluster + cache *AdsCache + dnsResolver *dns.DNSResolver // Store the copy of pendingResolveDomain. clusterCache map[string]*pendingResolveDomain // store all pending hostnames in the clusters @@ -47,7 +47,6 @@ type dnsController struct { } // pending resolve domain info of Kennel-Native Mode, -// domain name is used for dns resolution // cluster is used for create the apicluster type pendingResolveDomain struct { Clusters []*clusterv3.Cluster @@ -60,7 +59,7 @@ func NewDnsController(adsCache *AdsCache) (*dnsController, error) { return nil, err } return &dnsController{ - Clusters: make(chan []*clusterv3.Cluster), + clustersChan: make(chan []*clusterv3.Cluster), cache: adsCache, dnsResolver: resolver, clusterCache: make(map[string]*pendingResolveDomain), @@ -77,18 +76,18 @@ func (r *dnsController) Run(stopCh <-chan struct{}) { go r.processClusters() go func() { <-stopCh - close(r.Clusters) + close(r.clustersChan) }() } func (r *dnsController) processClusters() { - for clusters := range r.Clusters { + for clusters := range r.clustersChan { r.processDomains(clusters) } } func (r *dnsController) processDomains(cds []*clusterv3.Cluster) { - domains, hostNames := getPendingResolveDomain(cds) + domains := getPendingResolveDomain(cds) // store all pending hostnames of clusters in pendingHostnames for _, cluster := range cds { @@ -98,22 +97,22 @@ func (r *dnsController) processDomains(cds []*clusterv3.Cluster) { } // delete any scheduled re-resolve for domains we no longer care about - r.dnsResolver.RemoveUnwatchDomain(hostNames) + r.dnsResolver.RemoveUnwatchDomain(domains) // Update clusters based on the data in the dns cache. for k, v := range domains { addresses := r.dnsResolver.GetDNSAddresses(k) // Already have record in dns cache if addresses != nil { - // k(domain) has been resolved, triggering refreshWorker. 
- r.dnsResolver.DnsChan <- k + // Use a goroutine to update the Cluster, reducing the processing time of functions + // Avoiding clusterChan blocking + go r.updateClusters(v.(*pendingResolveDomain), k, addresses) } else { // Initialize the newly added hostname // and add it to the dns queue to be resolved. - r.dnsResolver.AddDomainInCache(k) domainInfo := &dns.DomainInfo{ Domain: k, - RefreshRate: v.RefreshRate, + RefreshRate: v.(*pendingResolveDomain).RefreshRate, } r.dnsResolver.AddDomainInQueue(domainInfo, 0) } @@ -126,36 +125,33 @@ func (r *dnsController) refreshWorker(stop <-chan struct{}) { select { case <-stop: return - default: - domain := <-r.dnsResolver.DnsChan + case domain := <-r.dnsResolver.DnsChan: pendingDomain := r.getClustersByDomain(domain) addrs := r.dnsResolver.GetDNSAddresses(domain) - ready := r.updateClusters(pendingDomain, domain, addrs) - if ready { - go r.cache.ClusterCache.Flush() - } + r.updateClusters(pendingDomain, domain, addrs) } } } -func (r *dnsController) updateClusters(pendingDomain *pendingResolveDomain, domain string, addrs []string) bool { +func (r *dnsController) updateClusters(pendingDomain *pendingResolveDomain, domain string, addrs []string) { isClusterUpdate := false if pendingDomain == nil || addrs == nil { - return false + return } for _, cluster := range pendingDomain.Clusters { ready, newCluster := r.overwriteDnsCluster(cluster, domain, addrs) if ready { if !r.cache.UpdateApiClusterIfExists(core_v2.ApiStatus_UPDATE, newCluster) { log.Debugf("cluster: %s is deleted", cluster.Name) - return false } else { isClusterUpdate = true } } } // if one cluster update successful, we will retuen true - return isClusterUpdate + if isClusterUpdate { + r.cache.ClusterCache.Flush() + } } func (r *dnsController) overwriteDnsCluster(cluster *clusterv3.Cluster, domain string, addrs []string) (bool, *clusterv3.Cluster) { @@ -255,9 +251,9 @@ func getHostName(cluster *clusterv3.Cluster) []string { return info } -func 
getPendingResolveDomain(cds []*clusterv3.Cluster) (map[string]*pendingResolveDomain, map[string]struct{}) { - domains := make(map[string]*pendingResolveDomain) - hostNames := make(map[string]struct{}) +func getPendingResolveDomain(cds []*clusterv3.Cluster) map[string]interface{} { + domains := make(map[string]interface{}) + // hostNames := make(map[string]struct{}) for _, cluster := range cds { if cluster.LoadAssignment == nil { @@ -277,26 +273,26 @@ func getPendingResolveDomain(cds []*clusterv3.Cluster) (map[string]*pendingResol } if v, ok := domains[address]; ok { - v.Clusters = append(v.Clusters, cluster) + v.(*pendingResolveDomain).Clusters = append(v.(*pendingResolveDomain).Clusters, cluster) } else { domainWithRefreshRate := &pendingResolveDomain{ Clusters: []*clusterv3.Cluster{cluster}, RefreshRate: cluster.GetDnsRefreshRate().AsDuration(), } domains[address] = domainWithRefreshRate - hostNames[address] = struct{}{} } } } } - return domains, hostNames + return domains } func (r *dnsController) newClusterCache() { + r.Lock() + defer r.Unlock() + if r.clusterCache != nil { - r.Lock() - defer r.Unlock() log.Debug("clean up dns clusters") r.clusterCache = map[string]*pendingResolveDomain{} return @@ -304,9 +300,10 @@ func (r *dnsController) newClusterCache() { } func (r *dnsController) getClustersByDomain(domain string) *pendingResolveDomain { + r.RLock() + defer r.RUnlock() + if r.clusterCache != nil { - r.RLock() - defer r.RUnlock() if v, ok := r.clusterCache[domain]; ok { return v } diff --git a/pkg/controller/ads/dns_test.go b/pkg/controller/ads/dns_test.go index 15556acca..34c5ef87a 100644 --- a/pkg/controller/ads/dns_test.go +++ b/pkg/controller/ads/dns_test.go @@ -79,7 +79,7 @@ func TestOverwriteDNSCluster(t *testing.T) { defer close(stopCh) dnsResolver, err := NewDnsController(p.Cache) assert.NoError(t, err) - p.DnsResolverChan = dnsResolver.Clusters + p.DnsResolverChan = dnsResolver.clustersChan dnsResolver.pendingHostnames = map[string][]string{ 
cluster.GetName(): { domain, @@ -227,11 +227,11 @@ func TestHandleCdsResponseWithDns(t *testing.T) { dnsResolver, err := NewDnsController(p.Cache) assert.NoError(t, err) dnsResolver.Run(stopCh) - p.DnsResolverChan = dnsResolver.Clusters + p.DnsResolverChan = dnsResolver.clustersChan for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { // notify dns resolver - dnsResolver.Clusters <- tc.clusters + dnsResolver.clustersChan <- tc.clusters retry.UntilOrFail(t, func() bool { return slices.EqualUnordered(tc.expected, dnsResolver.dnsResolver.GetAllCachedDomains()) }, retry.Timeout(1*time.Second)) @@ -302,7 +302,7 @@ func TestGetPendingResolveDomain(t *testing.T) { tests := []struct { name string args args - want map[string]*pendingResolveDomain + want map[string]interface{} }{ { name: "empty domains test", @@ -311,7 +311,7 @@ func TestGetPendingResolveDomain(t *testing.T) { &utCluster, }, }, - want: map[string]*pendingResolveDomain{}, + want: map[string]interface{}{}, }, { name: "cluster domain is not IP", @@ -320,18 +320,19 @@ func TestGetPendingResolveDomain(t *testing.T) { &utClusterWithHost, }, }, - want: map[string]*pendingResolveDomain{ - "www.google.com": { - Clusters: []*clusterv3.Cluster{&utClusterWithHost}, + want: map[string]interface{}{ + "www.google.com": &pendingResolveDomain{ + Clusters: []*clusterv3.Cluster{ + &utClusterWithHost, + }, }, }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if got, _ := getPendingResolveDomain(tt.args.clusters); !reflect.DeepEqual(got, tt.want) { - t.Errorf("getPendingResolveDomain() = %v, want %v", got, tt.want) - } + got := getPendingResolveDomain(tt.args.clusters) + assert.Equal(t, tt.want, got) }) } } diff --git a/pkg/dns/dns.go b/pkg/dns/dns.go index 5accea0fa..48787a80c 100644 --- a/pkg/dns/dns.go +++ b/pkg/dns/dns.go @@ -147,14 +147,6 @@ func (r *DNSResolver) resolve(domainName string) ([]string, time.Duration, error return addrs, ttl, nil } -func (r *DNSResolver) 
AddDomainInCache(domainName string) { - r.Lock() - if r.cache[domainName] == nil { - r.cache[domainName] = &DomainCacheEntry{} - } - r.Unlock() -} - // doResolve is copied and adapted from github.com/istio/istio/pilot/pkg/model/network.go. func (r *DNSResolver) doResolve(domain string) ([]string, time.Duration, error) { var out []string @@ -285,7 +277,7 @@ func (r *DNSResolver) GetBatchAddressesFromCache(domains map[string]struct{}) ma return alreadyResolveDomains } -func (r *DNSResolver) RemoveUnwatchDomain(domains map[string]struct{}) { +func (r *DNSResolver) RemoveUnwatchDomain(domains map[string]interface{}) { r.Lock() defer r.Unlock() @@ -301,5 +293,13 @@ func (r *DNSResolver) AddDomainInQueue(info *DomainInfo, time time.Duration) { if info == nil { return } + + // Initialize an empty cache entry for the pending domain if it has none yet. + r.Lock() + if r.cache[info.Domain] == nil { + r.cache[info.Domain] = &DomainCacheEntry{} + } + r.Unlock() + r.refreshQueue.AddAfter(info, time) }