Skip to content

Commit 115757a

Browse files
committed
feat(rule): support rule subscribe
1 parent 73dabb2 commit 115757a

File tree

3 files changed

+452
-63
lines changed

3 files changed

+452
-63
lines changed

rule.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,12 @@ func (r Rule) String() string {
4646
r.Priority, from, to, r.Table, r.typeString())
4747
}
4848

49+
// RuleUpdate is sent when a route changes - type is RTM_NEWRULE or RTM_DELRULE
50+
type RuleUpdate struct {
51+
Type uint16
52+
Rule
53+
}
54+
4955
// NewRule return empty rules.
5056
func NewRule() *Rule {
5157
return &Rule{

rule_linux.go

Lines changed: 193 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,10 @@ import (
55
"errors"
66
"fmt"
77
"net"
8+
"syscall"
89

910
"github.com/vishvananda/netlink/nl"
11+
"github.com/vishvananda/netns"
1012
"golang.org/x/sys/unix"
1113
)
1214

@@ -227,73 +229,11 @@ func (h *Handle) RuleListFiltered(family int, filter *Rule, filterMask uint64) (
227229

228230
var res = make([]Rule, 0)
229231
for i := range msgs {
230-
msg := nl.DeserializeRtMsg(msgs[i])
231-
attrs, err := nl.ParseRouteAttr(msgs[i][msg.Len():])
232+
rule, err := deserializeRule(msgs[i])
232233
if err != nil {
233234
return nil, err
234235
}
235236

236-
rule := NewRule()
237-
rule.Priority = 0 // The default priority from kernel
238-
239-
rule.Invert = msg.Flags&FibRuleInvert > 0
240-
rule.Family = int(msg.Family)
241-
rule.Tos = uint(msg.Tos)
242-
243-
for j := range attrs {
244-
switch attrs[j].Attr.Type {
245-
case unix.RTA_TABLE:
246-
rule.Table = int(native.Uint32(attrs[j].Value[0:4]))
247-
case nl.FRA_SRC:
248-
rule.Src = &net.IPNet{
249-
IP: attrs[j].Value,
250-
Mask: net.CIDRMask(int(msg.Src_len), 8*len(attrs[j].Value)),
251-
}
252-
case nl.FRA_DST:
253-
rule.Dst = &net.IPNet{
254-
IP: attrs[j].Value,
255-
Mask: net.CIDRMask(int(msg.Dst_len), 8*len(attrs[j].Value)),
256-
}
257-
case nl.FRA_FWMARK:
258-
rule.Mark = native.Uint32(attrs[j].Value[0:4])
259-
case nl.FRA_FWMASK:
260-
mask := native.Uint32(attrs[j].Value[0:4])
261-
rule.Mask = &mask
262-
case nl.FRA_TUN_ID:
263-
rule.TunID = uint(native.Uint64(attrs[j].Value[0:8]))
264-
case nl.FRA_IIFNAME:
265-
rule.IifName = string(attrs[j].Value[:len(attrs[j].Value)-1])
266-
case nl.FRA_OIFNAME:
267-
rule.OifName = string(attrs[j].Value[:len(attrs[j].Value)-1])
268-
case nl.FRA_SUPPRESS_PREFIXLEN:
269-
i := native.Uint32(attrs[j].Value[0:4])
270-
if i != 0xffffffff {
271-
rule.SuppressPrefixlen = int(i)
272-
}
273-
case nl.FRA_SUPPRESS_IFGROUP:
274-
i := native.Uint32(attrs[j].Value[0:4])
275-
if i != 0xffffffff {
276-
rule.SuppressIfgroup = int(i)
277-
}
278-
case nl.FRA_FLOW:
279-
rule.Flow = int(native.Uint32(attrs[j].Value[0:4]))
280-
case nl.FRA_GOTO:
281-
rule.Goto = int(native.Uint32(attrs[j].Value[0:4]))
282-
case nl.FRA_PRIORITY:
283-
rule.Priority = int(native.Uint32(attrs[j].Value[0:4]))
284-
case nl.FRA_IP_PROTO:
285-
rule.IPProto = int(native.Uint32(attrs[j].Value[0:4]))
286-
case nl.FRA_DPORT_RANGE:
287-
rule.Dport = NewRulePortRange(native.Uint16(attrs[j].Value[0:2]), native.Uint16(attrs[j].Value[2:4]))
288-
case nl.FRA_SPORT_RANGE:
289-
rule.Sport = NewRulePortRange(native.Uint16(attrs[j].Value[0:2]), native.Uint16(attrs[j].Value[2:4]))
290-
case nl.FRA_UID_RANGE:
291-
rule.UIDRange = NewRuleUIDRange(native.Uint32(attrs[j].Value[0:4]), native.Uint32(attrs[j].Value[4:8]))
292-
case nl.FRA_PROTOCOL:
293-
rule.Protocol = uint8(attrs[j].Value[0])
294-
}
295-
}
296-
297237
if filter != nil {
298238
switch {
299239
case filterMask&RT_FILTER_SRC != 0 &&
@@ -376,3 +316,193 @@ func (r Rule) typeString() string {
376316
return fmt.Sprintf("type(0x%x)", r.Type)
377317
}
378318
}
319+
320+
// deserializeRule decodes a binary netlink message into a Rule struct
321+
func deserializeRule(m []byte) (*Rule, error) {
322+
msg := nl.DeserializeRtMsg(m)
323+
attrs, err := nl.ParseRouteAttr(m[msg.Len():])
324+
if err != nil {
325+
return nil, err
326+
}
327+
328+
rule := NewRule()
329+
rule.Priority = 0 // The default priority from kernel
330+
331+
rule.Invert = msg.Flags&FibRuleInvert > 0
332+
rule.Family = int(msg.Family)
333+
rule.Tos = uint(msg.Tos)
334+
335+
for j := range attrs {
336+
switch attrs[j].Attr.Type {
337+
case unix.RTA_TABLE:
338+
rule.Table = int(native.Uint32(attrs[j].Value[0:4]))
339+
case nl.FRA_SRC:
340+
rule.Src = &net.IPNet{
341+
IP: attrs[j].Value,
342+
Mask: net.CIDRMask(int(msg.Src_len), 8*len(attrs[j].Value)),
343+
}
344+
case nl.FRA_DST:
345+
rule.Dst = &net.IPNet{
346+
IP: attrs[j].Value,
347+
Mask: net.CIDRMask(int(msg.Dst_len), 8*len(attrs[j].Value)),
348+
}
349+
case nl.FRA_FWMARK:
350+
rule.Mark = native.Uint32(attrs[j].Value[0:4])
351+
case nl.FRA_FWMASK:
352+
mask := native.Uint32(attrs[j].Value[0:4])
353+
rule.Mask = &mask
354+
case nl.FRA_TUN_ID:
355+
rule.TunID = uint(native.Uint64(attrs[j].Value[0:8]))
356+
case nl.FRA_IIFNAME:
357+
rule.IifName = string(attrs[j].Value[:len(attrs[j].Value)-1])
358+
case nl.FRA_OIFNAME:
359+
rule.OifName = string(attrs[j].Value[:len(attrs[j].Value)-1])
360+
case nl.FRA_SUPPRESS_PREFIXLEN:
361+
i := native.Uint32(attrs[j].Value[0:4])
362+
if i != 0xffffffff {
363+
rule.SuppressPrefixlen = int(i)
364+
}
365+
case nl.FRA_SUPPRESS_IFGROUP:
366+
i := native.Uint32(attrs[j].Value[0:4])
367+
if i != 0xffffffff {
368+
rule.SuppressIfgroup = int(i)
369+
}
370+
case nl.FRA_FLOW:
371+
rule.Flow = int(native.Uint32(attrs[j].Value[0:4]))
372+
case nl.FRA_GOTO:
373+
rule.Goto = int(native.Uint32(attrs[j].Value[0:4]))
374+
case nl.FRA_PRIORITY:
375+
rule.Priority = int(native.Uint32(attrs[j].Value[0:4]))
376+
case nl.FRA_IP_PROTO:
377+
rule.IPProto = int(native.Uint32(attrs[j].Value[0:4]))
378+
case nl.FRA_DPORT_RANGE:
379+
rule.Dport = NewRulePortRange(native.Uint16(attrs[j].Value[0:2]), native.Uint16(attrs[j].Value[2:4]))
380+
case nl.FRA_SPORT_RANGE:
381+
rule.Sport = NewRulePortRange(native.Uint16(attrs[j].Value[0:2]), native.Uint16(attrs[j].Value[2:4]))
382+
case nl.FRA_UID_RANGE:
383+
rule.UIDRange = NewRuleUIDRange(native.Uint32(attrs[j].Value[0:4]), native.Uint32(attrs[j].Value[4:8]))
384+
case nl.FRA_PROTOCOL:
385+
rule.Protocol = uint8(attrs[j].Value[0])
386+
}
387+
}
388+
389+
return rule, nil
390+
}
391+
392+
// RuleSubscribe takes a chan down which notifications will be sent
393+
// when rules are added or deleted. Close the 'done' chan to stop subscription.
394+
func RuleSubscribe(ch chan<- RuleUpdate, done <-chan struct{}) error {
395+
return ruleSubscribeAt(netns.None(), netns.None(), ch, done, nil, false, 0, nil, false)
396+
}
397+
398+
// RuleSubscribeAt works like RuleSubscribe plus it allows the caller
399+
// to choose the network namespace in which to subscribe (ns).
400+
func RuleSubscribeAt(ns netns.NsHandle, ch chan<- RuleUpdate, done <-chan struct{}) error {
401+
return ruleSubscribeAt(ns, netns.None(), ch, done, nil, false, 0, nil, false)
402+
}
403+
404+
// RuleSubscribeOptions contains a set of options to use with
405+
// RuleSubscribeWithOptions.
406+
type RuleSubscribeOptions struct {
407+
Namespace *netns.NsHandle
408+
ErrorCallback func(error)
409+
ListExisting bool
410+
ReceiveBufferSize int
411+
ReceiveBufferForceSize bool
412+
ReceiveTimeout *unix.Timeval
413+
}
414+
415+
// RuleSubscribeWithOptions work like RuleSubscribe but enable to
416+
// provide additional options to modify the behavior. Currently, the
417+
// namespace can be provided as well as an error callback.
418+
func RuleSubscribeWithOptions(ch chan<- RuleUpdate, done <-chan struct{}, options RuleSubscribeOptions) error {
419+
if options.Namespace == nil {
420+
none := netns.None()
421+
options.Namespace = &none
422+
}
423+
return ruleSubscribeAt(*options.Namespace, netns.None(), ch, done, options.ErrorCallback, options.ListExisting,
424+
options.ReceiveBufferSize, options.ReceiveTimeout, options.ReceiveBufferForceSize)
425+
}
426+
427+
func ruleSubscribeAt(newNs, curNs netns.NsHandle, ch chan<- RuleUpdate, done <-chan struct{}, cberr func(error), listExisting bool,
428+
rcvbuf int, rcvTimeout *unix.Timeval, rcvbufForce bool) error {
429+
s, err := nl.SubscribeAt(newNs, curNs, unix.NETLINK_ROUTE, unix.RTNLGRP_IPV4_RULE, unix.RTNLGRP_IPV6_RULE)
430+
if err != nil {
431+
return err
432+
}
433+
if rcvTimeout != nil {
434+
if err := s.SetReceiveTimeout(rcvTimeout); err != nil {
435+
return err
436+
}
437+
}
438+
if rcvbuf != 0 {
439+
err = s.SetReceiveBufferSize(rcvbuf, rcvbufForce)
440+
if err != nil {
441+
return err
442+
}
443+
}
444+
if done != nil {
445+
go func() {
446+
<-done
447+
s.Close()
448+
}()
449+
}
450+
if listExisting {
451+
req := pkgHandle.newNetlinkRequest(unix.RTM_GETRULE,
452+
unix.NLM_F_DUMP)
453+
infmsg := nl.NewIfInfomsg(unix.AF_UNSPEC)
454+
req.AddData(infmsg)
455+
if err := s.Send(req); err != nil {
456+
return err
457+
}
458+
}
459+
go func() {
460+
defer close(ch)
461+
for {
462+
msgs, from, err := s.Receive()
463+
if err != nil {
464+
if cberr != nil {
465+
cberr(fmt.Errorf("Receive failed: %v",
466+
err))
467+
}
468+
if err.Error() != "use of closed file" && err.Error() != "file already closed" {
469+
panic(fmt.Sprintf("err: %v", err))
470+
}
471+
return
472+
}
473+
if from.Pid != nl.PidKernel {
474+
if cberr != nil {
475+
cberr(fmt.Errorf("Wrong sender portid %d, expected %d", from.Pid, nl.PidKernel))
476+
}
477+
panic(fmt.Sprintf("from.Pid: %d, nl.PidKernel: %d", from.Pid, nl.PidKernel))
478+
continue
479+
}
480+
for _, m := range msgs {
481+
if m.Header.Type == unix.NLMSG_DONE {
482+
continue
483+
}
484+
if m.Header.Type == unix.NLMSG_ERROR {
485+
error := int32(native.Uint32(m.Data[0:4]))
486+
if error == 0 {
487+
continue
488+
}
489+
if cberr != nil {
490+
cberr(fmt.Errorf("error message: %v",
491+
syscall.Errno(-error)))
492+
}
493+
continue
494+
}
495+
rule, err := deserializeRule(m.Data)
496+
if err != nil {
497+
if cberr != nil {
498+
cberr(err)
499+
}
500+
continue
501+
}
502+
ch <- RuleUpdate{Type: m.Header.Type, Rule: *rule}
503+
}
504+
}
505+
}()
506+
507+
return nil
508+
}

0 commit comments

Comments
 (0)