Skip to content

Commit 51c44dc

Browse files
authored
Implement AddGenerationalMonitor to deliver monitor events in batches (google#283)
1 parent ed578af commit 51c44dc

File tree

3 files changed

+180
-55
lines changed

3 files changed

+180
-55
lines changed

gen.go

+45
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
package nftables
2+
3+
import (
4+
"encoding/binary"
5+
"fmt"
6+
"github.com/mdlayher/netlink"
7+
"golang.org/x/sys/unix"
8+
)
9+
10+
type GenMsg struct {
11+
ID uint32
12+
ProcPID uint32
13+
ProcComm string // [16]byte - max 16bytes - kernel TASK_COMM_LEN
14+
}
15+
16+
var genHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWGEN)
17+
18+
func genFromMsg(msg netlink.Message) (*GenMsg, error) {
19+
if got, want := msg.Header.Type, genHeaderType; got != want {
20+
return nil, fmt.Errorf("unexpected header type: got %v, want %v", got, want)
21+
}
22+
ad, err := netlink.NewAttributeDecoder(msg.Data[4:])
23+
if err != nil {
24+
return nil, err
25+
}
26+
ad.ByteOrder = binary.BigEndian
27+
28+
msgOut := &GenMsg{}
29+
for ad.Next() {
30+
switch ad.Type() {
31+
case unix.NFTA_GEN_ID:
32+
msgOut.ID = ad.Uint32()
33+
case unix.NFTA_GEN_PROC_PID:
34+
msgOut.ProcPID = ad.Uint32()
35+
case unix.NFTA_GEN_PROC_NAME:
36+
msgOut.ProcComm = ad.String()
37+
default:
38+
return nil, fmt.Errorf("Unknown attribute: %d %v\n", ad.Type(), ad.Bytes())
39+
}
40+
}
41+
if err := ad.Err(); err != nil {
42+
return nil, err
43+
}
44+
return msgOut, nil
45+
}

monitor.go

+90-32
Original file line numberDiff line numberDiff line change
@@ -116,9 +116,15 @@ const (
116116
// nftables.MonitorEventTypeNewTable, you can access the corresponding table
117117
// details via Data.(*nftables.Table).
118118
type MonitorEvent struct {
119-
Type MonitorEventType
120-
Data any
121-
Error error
119+
Header netlink.Header
120+
Type MonitorEventType
121+
Data any
122+
Error error
123+
}
124+
125+
type MonitorEvents struct {
126+
GeneratedBy *MonitorEvent
127+
Changes []*MonitorEvent
122128
}
123129

124130
const (
@@ -139,15 +145,15 @@ type Monitor struct {
139145

140146
// mu covers eventCh and status
141147
mu sync.Mutex
142-
eventCh chan *MonitorEvent
148+
eventCh chan *MonitorEvents
143149
status int
144150
}
145151

146152
type MonitorOption func(*Monitor)
147153

148154
func WithMonitorEventBuffer(size int) MonitorOption {
149155
return func(monitor *Monitor) {
150-
monitor.eventCh = make(chan *MonitorEvent, size)
156+
monitor.eventCh = make(chan *MonitorEvents, size)
151157
}
152158
}
153159

@@ -177,7 +183,7 @@ func NewMonitor(opts ...MonitorOption) *Monitor {
177183
opt(monitor)
178184
}
179185
if monitor.eventCh == nil {
180-
monitor.eventCh = make(chan *MonitorEvent)
186+
monitor.eventCh = make(chan *MonitorEvents)
181187
}
182188
objects, ok := monitorFlags[monitor.action]
183189
if !ok {
@@ -192,20 +198,30 @@ func NewMonitor(opts ...MonitorOption) *Monitor {
192198
}
193199

194200
func (monitor *Monitor) monitor() {
201+
var changesEvents []*MonitorEvent
202+
195203
for {
196204
msgs, err := monitor.conn.Receive()
197205
if err != nil {
198206
if strings.Contains(err.Error(), "use of closed file") {
199207
// ignore the error that be closed
200208
break
201209
} else {
202-
// any other errors will be send to user, and then to close eventCh
210+
// any other errors will be sent to user, and then to close eventCh
203211
event := &MonitorEvent{
204212
Type: MonitorEventTypeOOB,
205213
Data: nil,
206214
Error: err,
207215
}
208-
monitor.eventCh <- event
216+
217+
changesEvents = append(changesEvents, event)
218+
219+
monitor.eventCh <- &MonitorEvents{
220+
GeneratedBy: event,
221+
Changes: changesEvents,
222+
}
223+
changesEvents = nil
224+
209225
break
210226
}
211227
}
@@ -221,54 +237,76 @@ func (monitor *Monitor) monitor() {
221237
case unix.NFT_MSG_NEWTABLE, unix.NFT_MSG_DELTABLE:
222238
table, err := tableFromMsg(msg)
223239
event := &MonitorEvent{
224-
Type: MonitorEventType(msgType),
225-
Data: table,
226-
Error: err,
240+
Type: MonitorEventType(msgType),
241+
Data: table,
242+
Error: err,
243+
Header: msg.Header,
227244
}
228-
monitor.eventCh <- event
245+
changesEvents = append(changesEvents, event)
229246
case unix.NFT_MSG_NEWCHAIN, unix.NFT_MSG_DELCHAIN:
230247
chain, err := chainFromMsg(msg)
231248
event := &MonitorEvent{
232-
Type: MonitorEventType(msgType),
233-
Data: chain,
234-
Error: err,
249+
Type: MonitorEventType(msgType),
250+
Data: chain,
251+
Error: err,
252+
Header: msg.Header,
235253
}
236-
monitor.eventCh <- event
254+
changesEvents = append(changesEvents, event)
237255
case unix.NFT_MSG_NEWRULE, unix.NFT_MSG_DELRULE:
238256
rule, err := parseRuleFromMsg(msg)
239257
event := &MonitorEvent{
240-
Type: MonitorEventType(msgType),
241-
Data: rule,
242-
Error: err,
258+
Type: MonitorEventType(msgType),
259+
Data: rule,
260+
Error: err,
261+
Header: msg.Header,
243262
}
244-
monitor.eventCh <- event
263+
changesEvents = append(changesEvents, event)
245264
case unix.NFT_MSG_NEWSET, unix.NFT_MSG_DELSET:
246265
set, err := setsFromMsg(msg)
247266
event := &MonitorEvent{
248-
Type: MonitorEventType(msgType),
249-
Data: set,
250-
Error: err,
267+
Type: MonitorEventType(msgType),
268+
Data: set,
269+
Error: err,
270+
Header: msg.Header,
251271
}
252-
monitor.eventCh <- event
272+
changesEvents = append(changesEvents, event)
253273
case unix.NFT_MSG_NEWSETELEM, unix.NFT_MSG_DELSETELEM:
254274
elems, err := elementsFromMsg(uint8(TableFamilyUnspecified), msg)
255275
event := &MonitorEvent{
256-
Type: MonitorEventType(msgType),
257-
Data: elems,
258-
Error: err,
276+
Type: MonitorEventType(msgType),
277+
Data: elems,
278+
Error: err,
279+
Header: msg.Header,
259280
}
260-
monitor.eventCh <- event
281+
changesEvents = append(changesEvents, event)
261282
case unix.NFT_MSG_NEWOBJ, unix.NFT_MSG_DELOBJ:
262283
obj, err := objFromMsg(msg, true)
263284
event := &MonitorEvent{
264-
Type: MonitorEventType(msgType),
265-
Data: obj,
266-
Error: err,
285+
Type: MonitorEventType(msgType),
286+
Data: obj,
287+
Error: err,
288+
Header: msg.Header,
289+
}
290+
changesEvents = append(changesEvents, event)
291+
case unix.NFT_MSG_NEWGEN:
292+
gen, err := genFromMsg(msg)
293+
event := &MonitorEvent{
294+
Type: MonitorEventType(msgType),
295+
Data: gen,
296+
Error: err,
297+
Header: msg.Header,
267298
}
268-
monitor.eventCh <- event
299+
300+
monitor.eventCh <- &MonitorEvents{
301+
GeneratedBy: event,
302+
Changes: changesEvents,
303+
}
304+
305+
changesEvents = nil
269306
}
270307
}
271308
}
309+
272310
monitor.mu.Lock()
273311
defer monitor.mu.Unlock()
274312

@@ -294,6 +332,26 @@ func (monitor *Monitor) Close() error {
294332
// Caller may receive a MonitorEventTypeOOB event which contains an error we didn't
295333
// handle, for now.
296334
func (cc *Conn) AddMonitor(monitor *Monitor) (chan *MonitorEvent, error) {
335+
generationalEventCh, err := cc.AddGenerationalMonitor(monitor)
336+
if err != nil {
337+
return nil, err
338+
}
339+
340+
eventCh := make(chan *MonitorEvent)
341+
342+
go func() {
343+
defer close(eventCh)
344+
for monitorEvents := range generationalEventCh {
345+
for _, event := range monitorEvents.Changes {
346+
eventCh <- event
347+
}
348+
}
349+
}()
350+
351+
return eventCh, nil
352+
}
353+
354+
func (cc *Conn) AddGenerationalMonitor(monitor *Monitor) (chan *MonitorEvents, error) {
297355
conn, closer, err := cc.netlinkConn()
298356
if err != nil {
299357
return nil, err

monitor_test.go

+45-23
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ import (
44
"fmt"
55
"log"
66
"net"
7+
"os"
8+
"path/filepath"
79
"sync"
810
"sync/atomic"
911
"testing"
@@ -21,17 +23,19 @@ func ExampleNewMonitor() {
2123

2224
mon := nftables.NewMonitor()
2325
defer mon.Close()
24-
events, err := conn.AddMonitor(mon)
26+
events, err := conn.AddGenerationalMonitor(mon)
2527
if err != nil {
2628
log.Fatal(err)
2729
}
2830
for ev := range events {
29-
log.Printf("ev: %+v, data = %T", ev, ev.Data)
30-
switch ev.Type {
31-
case nftables.MonitorEventTypeNewTable:
32-
log.Printf("data = %+v", ev.Data.(*nftables.Table))
31+
log.Printf("ev: %+v, data = %T", ev, ev.Changes)
3332

34-
// …more cases if needed…
33+
for _, change := range ev.Changes {
34+
switch change.Type {
35+
case nftables.MonitorEventTypeNewTable:
36+
log.Printf("data = %+v", change.Data.(*nftables.Table))
37+
// …more cases if needed…
38+
}
3539
}
3640
}
3741
}
@@ -44,10 +48,9 @@ func TestMonitor(t *testing.T) {
4448
// Clear all rules at the beginning + end of the test.
4549
c.FlushRuleset()
4650
defer c.FlushRuleset()
47-
4851
// default to monitor all
4952
monitor := nftables.NewMonitor()
50-
events, err := c.AddMonitor(monitor)
53+
events, err := c.AddGenerationalMonitor(monitor)
5154
if err != nil {
5255
t.Fatal(err)
5356
}
@@ -58,6 +61,7 @@ func TestMonitor(t *testing.T) {
5861
var gotRule *nftables.Rule
5962
wg := sync.WaitGroup{}
6063
wg.Add(1)
64+
var errMonitor error
6165
go func() {
6266
defer wg.Done()
6367
count := int32(0)
@@ -66,23 +70,35 @@ func TestMonitor(t *testing.T) {
6670
if !ok {
6771
return
6872
}
69-
if event.Error != nil {
70-
err = fmt.Errorf("monitor err: %s", event.Error)
73+
74+
genMsg := event.GeneratedBy.Data.(*nftables.GenMsg)
75+
fileName := filepath.Base(os.Args[0])
76+
77+
if genMsg.ProcComm != fileName {
78+
errMonitor = fmt.Errorf("procComm: %s, want: %s", genMsg.ProcComm, fileName)
7179
return
7280
}
73-
switch event.Type {
74-
case nftables.MonitorEventTypeNewTable:
75-
gotTable = event.Data.(*nftables.Table)
76-
atomic.AddInt32(&count, 1)
77-
case nftables.MonitorEventTypeNewChain:
78-
gotChain = event.Data.(*nftables.Chain)
79-
atomic.AddInt32(&count, 1)
80-
case nftables.MonitorEventTypeNewRule:
81-
gotRule = event.Data.(*nftables.Rule)
82-
atomic.AddInt32(&count, 1)
83-
}
84-
if atomic.LoadInt32(&count) == 3 {
85-
return
81+
82+
for _, change := range event.Changes {
83+
if change.Error != nil {
84+
errMonitor = fmt.Errorf("monitor err: %s", change.Error)
85+
return
86+
}
87+
88+
switch change.Type {
89+
case nftables.MonitorEventTypeNewTable:
90+
gotTable = change.Data.(*nftables.Table)
91+
atomic.AddInt32(&count, 1)
92+
case nftables.MonitorEventTypeNewChain:
93+
gotChain = change.Data.(*nftables.Chain)
94+
atomic.AddInt32(&count, 1)
95+
case nftables.MonitorEventTypeNewRule:
96+
gotRule = change.Data.(*nftables.Rule)
97+
atomic.AddInt32(&count, 1)
98+
}
99+
if atomic.LoadInt32(&count) == 3 {
100+
return
101+
}
86102
}
87103
}
88104
}()
@@ -126,7 +142,13 @@ func TestMonitor(t *testing.T) {
126142
if err := c.Flush(); err != nil {
127143
t.Fatal(err)
128144
}
145+
129146
wg.Wait()
147+
148+
if errMonitor != nil {
149+
t.Fatal("monitor err", errMonitor)
150+
}
151+
130152
if gotTable.Family != nat.Family || gotTable.Name != nat.Name {
131153
t.Fatal("no want table", gotTable.Family, gotTable.Name)
132154
}

0 commit comments

Comments
 (0)