diff --git a/conntrack_linux.go b/conntrack_linux.go index ff20869b..f4cdf081 100644 --- a/conntrack_linux.go +++ b/conntrack_linux.go @@ -85,6 +85,10 @@ func ConntrackDeleteFilters(table ConntrackTableType, family InetFamily, filters return pkgHandle.ConntrackDeleteFilters(table, family, filters...) } +func ConntrackTableListStream(table ConntrackTableType, family InetFamily, handle chan *ConntrackFlow) error { + return pkgHandle.ConntrackTableListStream(table, family, handle) +} + // ConntrackTableList returns the flow list of a table of a specific family using the netlink handle passed // conntrack -L [table] [options] List conntrack or expectation table // @@ -195,6 +199,17 @@ func (h *Handle) ConntrackDeleteFilters(table ConntrackTableType, family InetFam return matched, finalErr } +func (h *Handle) ConntrackTableListStream(table ConntrackTableType, family InetFamily, handle chan *ConntrackFlow) error { + req := h.newConntrackRequest(table, family, nl.IPCTNL_MSG_CT_GET, unix.NLM_F_DUMP) + + err := req.ExecuteIter(unix.NETLINK_NETFILTER, 0, func(dataRaw []byte) bool { + handle <- parseRawData(dataRaw) + return true + }) + + return err +} + func (h *Handle) newConntrackRequest(table ConntrackTableType, family InetFamily, operation, flags int) *nl.NetlinkRequest { // Create the Netlink request object req := h.newNetlinkRequest((int(table)<<8)|operation, flags) @@ -221,9 +236,12 @@ type ProtoInfo interface { } // ProtoInfoTCP corresponds to the `tcp` struct of the __nfct_protoinfo union. -// Only TCP state is currently supported. type ProtoInfoTCP struct { - State uint8 + State uint8 + WsacleOriginal uint8 + WsacleReply uint8 + FlagsOriginal uint16 + FlagsReply uint16 } // Protocol returns "tcp". @@ -233,6 +251,14 @@ func (p *ProtoInfoTCP) toNlData() ([]*nl.RtAttr, error) { ctProtoInfoTCP := nl.NewRtAttr(unix.NLA_F_NESTED|nl.CTA_PROTOINFO_TCP, []byte{}) ctProtoInfoTCPState := nl.NewRtAttr(nl.CTA_PROTOINFO_TCP_STATE, nl.Uint8Attr(p.State)) ctProtoInfoTCP.AddChild(ctProtoInfoTCPState) + ctProtoInfoTCPWscaleOriginal := nl.NewRtAttr(nl.CTA_PROTOINFO_TCP_WSCALE_ORIGINAL, nl.Uint8Attr(p.WsacleOriginal)) + ctProtoInfoTCP.AddChild(ctProtoInfoTCPWscaleOriginal) + ctProtoInfoTCPWscaleReply := nl.NewRtAttr(nl.CTA_PROTOINFO_TCP_WSCALE_REPLY, nl.Uint8Attr(p.WsacleReply)) + ctProtoInfoTCP.AddChild(ctProtoInfoTCPWscaleReply) + ctProtoInfoTCPFlagsOriginal := nl.NewRtAttr(nl.CTA_PROTOINFO_TCP_FLAGS_ORIGINAL, nl.BEUint16Attr(p.FlagsOriginal)) + ctProtoInfoTCP.AddChild(ctProtoInfoTCPFlagsOriginal) + ctProtoInfoTCPFlagsReply := nl.NewRtAttr(nl.CTA_PROTOINFO_TCP_FLAGS_REPLY, nl.BEUint16Attr(p.FlagsReply)) + ctProtoInfoTCP.AddChild(ctProtoInfoTCPFlagsReply) ctProtoInfo.AddChild(ctProtoInfoTCP) return []*nl.RtAttr{ctProtoInfo}, nil @@ -261,6 +287,11 @@ type IPTuple struct { Protocol uint8 SrcIP net.IP SrcPort uint16 + + // ICMP only + ICMPID uint16 + ICMPType uint8 + ICMPCode uint8 } // toNlData generates the inner fields of a nested tuple netlink datastructure @@ -304,7 +335,11 @@ type ConntrackFlow struct { TimeStart uint64 TimeStop uint64 TimeOut uint32 + Status uint32 + Use uint32 + ID uint32 Labels []byte + LabelsMask []byte ProtoInfo ProtoInfo } @@ -315,19 +350,37 @@ func (s *ConntrackFlow) String() string { start := time.Unix(0, int64(s.TimeStart)) stop := time.Unix(0, int64(s.TimeStop)) timeout := int32(s.TimeOut) - res := fmt.Sprintf("%s\t%d src=%s dst=%s sport=%d dport=%d packets=%d bytes=%d\tsrc=%s dst=%s sport=%d dport=%d packets=%d bytes=%d mark=0x%x ", - nl.L4ProtoMap[s.Forward.Protocol], s.Forward.Protocol, - s.Forward.SrcIP.String(), s.Forward.DstIP.String(), s.Forward.SrcPort, s.Forward.DstPort, s.Forward.Packets, s.Forward.Bytes, - s.Reverse.SrcIP.String(), s.Reverse.DstIP.String(), s.Reverse.SrcPort, s.Reverse.DstPort, s.Reverse.Packets, s.Reverse.Bytes, - s.Mark) + + var out string + if s.Forward.Protocol == unix.IPPROTO_ICMP || s.Forward.Protocol == unix.IPPROTO_ICMPV6 { + out = fmt.Sprintf("%s\t%d src=%s dst=%s id=%d type=%d code=%d packets=%d bytes=%d\tsrc=%s dst=%s id=%d type=%d code=%d packets=%d bytes=%d", + nl.L4ProtoMap[s.Forward.Protocol], s.Forward.Protocol, + s.Forward.SrcIP.String(), s.Forward.DstIP.String(), s.Forward.ICMPID, s.Forward.ICMPType, s.Forward.ICMPCode, s.Forward.Packets, s.Forward.Bytes, + s.Reverse.SrcIP.String(), s.Reverse.DstIP.String(), s.Reverse.ICMPID, s.Reverse.ICMPType, s.Reverse.ICMPCode, s.Reverse.Packets, s.Reverse.Bytes) + } else { + out = fmt.Sprintf("%s\t%d src=%s dst=%s sport=%d dport=%d packets=%d bytes=%d\tsrc=%s dst=%s sport=%d dport=%d packets=%d bytes=%d", + nl.L4ProtoMap[s.Forward.Protocol], s.Forward.Protocol, + s.Forward.SrcIP.String(), s.Forward.DstIP.String(), s.Forward.SrcPort, s.Forward.DstPort, s.Forward.Packets, s.Forward.Bytes, + s.Reverse.SrcIP.String(), s.Reverse.DstIP.String(), s.Reverse.SrcPort, s.Reverse.DstPort, s.Reverse.Packets, s.Reverse.Bytes) + } + out += fmt.Sprintf(" mark=0x%x", s.Mark) if len(s.Labels) > 0 { - res += fmt.Sprintf("labels=0x%x ", s.Labels) + out += fmt.Sprintf(" labels=0x%x", s.Labels) + } + if len(s.LabelsMask) > 0 { + out += fmt.Sprintf("/0x%x", s.LabelsMask) + } + if s.Status != 0 { + out += fmt.Sprintf(" status=0x%x", s.Status) } if s.Zone != 0 { - res += fmt.Sprintf("zone=%d ", s.Zone) + out += fmt.Sprintf(" zone=%d", s.Zone) } - res += fmt.Sprintf("start=%v stop=%v timeout=%d(sec)", start, stop, timeout) - return res + if s.Use != 0 { + out += fmt.Sprintf(" use=0x%x", s.Use) + } + out += fmt.Sprintf(" start=%v stop=%v timeout=%d(sec)", start, stop, timeout) + return out } // toNlData generates netlink messages representing the flow. @@ -444,8 +497,8 @@ func parseIpTuple(reader *bytes.Reader, tpl *IPTuple) uint8 { if t == nl.CTA_PROTO_NUM { tpl.Protocol = uint8(v[0]) } - // We only parse TCP & UDP headers. Skip the others. - if tpl.Protocol != unix.IPPROTO_TCP && tpl.Protocol != unix.IPPROTO_UDP { + // We only parse TCP, UDP, ICMP, ICMPv6 headers. Skip the others. + if tpl.Protocol != unix.IPPROTO_TCP && tpl.Protocol != unix.IPPROTO_UDP && tpl.Protocol != unix.IPPROTO_ICMP && tpl.Protocol != unix.IPPROTO_ICMPV6 { // skip the rest bytesRemaining := protoInfoTotalLen - protoInfoBytesRead reader.Seek(int64(bytesRemaining), seekCurrent) @@ -454,7 +507,12 @@ func parseIpTuple(reader *bytes.Reader, tpl *IPTuple) uint8 { // Skip 3 bytes of padding reader.Seek(3, seekCurrent) protoInfoBytesRead += 3 - for i := 0; i < 2; i++ { + loopCount := 2 + if tpl.Protocol == unix.IPPROTO_ICMP || tpl.Protocol == unix.IPPROTO_ICMPV6 { + loopCount = 3 // ID, Type, Code + } + var ICMPCodeDone, ICMPTypeDone bool + for i := 0; i < loopCount; i++ { _, t, _ := parseNfAttrTL(reader) protoInfoBytesRead += uint16(nl.SizeofNfattr) switch t { @@ -464,6 +522,26 @@ func parseIpTuple(reader *bytes.Reader, tpl *IPTuple) uint8 { case nl.CTA_PROTO_DST_PORT: parseBERaw16(reader, &tpl.DstPort) protoInfoBytesRead += 2 + case nl.CTA_PROTO_ICMP_ID: + fallthrough + case nl.CTA_PROTO_ICMPV6_ID: + parseBERaw16(reader, &tpl.ICMPID) + protoInfoBytesRead += 2 + case nl.CTA_PROTO_ICMP_CODE: + fallthrough + case nl.CTA_PROTO_ICMPV6_CODE: + parseU8(reader, &tpl.ICMPCode) + protoInfoBytesRead += 1 + ICMPCodeDone = true + case nl.CTA_PROTO_ICMP_TYPE: + fallthrough + case nl.CTA_PROTO_ICMPV6_TYPE: + parseU8(reader, &tpl.ICMPType) + protoInfoBytesRead += 1 + ICMPTypeDone = true + } + if (t == nl.CTA_PROTO_ICMP_CODE || t == nl.CTA_PROTO_ICMP_TYPE) && (!ICMPCodeDone || !ICMPTypeDone) { + continue } // Skip 2 bytes of padding reader.Seek(2, seekCurrent) @@ -503,6 +581,10 @@ func skipNfAttrValue(r *bytes.Reader, len uint16) uint16 { return len } +func parseU8(r *bytes.Reader, v *uint8) { + binary.Read(r, binary.BigEndian, v) +} + func parseBERaw16(r *bytes.Reader, v *uint16) { binary.Read(r, binary.BigEndian, v) } @@ -576,6 +658,22 @@ func parseProtoInfoTCP(r *bytes.Reader, attrLen uint16) *ProtoInfoTCP { case nl.CTA_PROTOINFO_TCP_STATE: p.State = parseProtoInfoTCPState(r) bytesRead += nl.SizeofNfattr + case nl.CTA_PROTOINFO_TCP_WSCALE_ORIGINAL: + parseU8(r, &p.WsacleOriginal) + r.Seek(nl.SizeofNfattr-1, seekCurrent) + bytesRead += nl.SizeofNfattr + case nl.CTA_PROTOINFO_TCP_WSCALE_REPLY: + parseU8(r, &p.WsacleReply) + r.Seek(nl.SizeofNfattr-1, seekCurrent) + bytesRead += nl.SizeofNfattr + case nl.CTA_PROTOINFO_TCP_FLAGS_ORIGINAL: + parseBERaw16(r, &p.FlagsOriginal) + r.Seek(nl.SizeofNfattr-2, seekCurrent) + bytesRead += nl.SizeofNfattr + case nl.CTA_PROTOINFO_TCP_FLAGS_REPLY: + parseBERaw16(r, &p.FlagsReply) + r.Seek(nl.SizeofNfattr-2, seekCurrent) + bytesRead += nl.SizeofNfattr default: bytesRead += int(skipNfAttrValue(r, l)) } @@ -679,14 +777,20 @@ func parseRawData(data []byte) *ConntrackFlow { switch t { case nl.CTA_MARK: s.Mark = parseConnectionMark(reader) + case nl.CTA_ZONE: + s.Zone = parseConnectionZone(reader) case nl.CTA_LABELS: s.Labels = parseConnectionLabels(reader) + case nl.CTA_LABELS_MASK: + s.LabelsMask = parseConnectionLabels(reader) case nl.CTA_TIMEOUT: s.TimeOut = parseTimeOut(reader) - case nl.CTA_ID, nl.CTA_STATUS, nl.CTA_USE: - skipNfAttrValue(reader, l) - case nl.CTA_ZONE: - s.Zone = parseConnectionZone(reader) + case nl.CTA_STATUS: + parseBERaw32(reader, &s.Status) + case nl.CTA_USE: + parseBERaw32(reader, &s.Use) + case nl.CTA_ID: + parseBERaw32(reader, &s.ID) default: skipNfAttrValue(reader, l) } diff --git a/conntrack_test.go b/conntrack_test.go index 48e5c4a1..978957ae 100644 --- a/conntrack_test.go +++ b/conntrack_test.go @@ -934,10 +934,14 @@ func TestParseRawData(t *testing.T) { 12, 0, 1, 0, 22, 134, 80, 142, 230, 127, 74, 166, /* >> CTA_LABELS */ - 20, 0, 22, 0, - 0, 0, 0, 0, 5, 0, 18, 172, 66, 2, 1, 0, 0, 0, 0, 0}, + 16, 0, 22, 0, + 34, 65, 12, 12, 91, 134, 145, 211, 123, 93, 13, 47, 95, 34, 15, 77, + /* >> CTA_LABELS_MASK */ + 16, 0, 23, 0, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255}, expConntrackFlow: "udp\t17 src=192.168.0.10 dst=192.168.0.3 sport=48385 dport=53 packets=1 bytes=55\t" + - "src=192.168.0.3 dst=192.168.0.10 sport=53 dport=48385 packets=1 bytes=71 mark=0x5 labels=0x00000000050012ac4202010000000000 " + + "src=192.168.0.3 dst=192.168.0.10 sport=53 dport=48385 packets=1 bytes=71 mark=0x5 " + + "labels=0x22410c0c5b8691d37b5d0d2f5f220f4d/0xffffffffffffffffffffffffffffffff status=0x18a use=0x1 " + "start=2021-06-07 13:41:30.39632247 +0000 UTC stop=1970-01-01 00:00:00 +0000 UTC timeout=32(sec)", }, { @@ -1033,10 +1037,17 @@ func TestParseRawData(t *testing.T) { 16, 0, 20, 128, /* >>>> CTA_TIMESTAMP_START */ 12, 0, 1, 0, - 22, 134, 80, 175, 134, 10, 182, 221}, + 22, 134, 80, 175, 134, 10, 182, 221, + /* >> CTA_LABELS */ + 16, 0, 22, 0, + 34, 65, 12, 12, 91, 134, 145, 211, 123, 93, 13, 47, 95, 34, 15, 77, + /* >> CTA_LABELS_MASK */ + 16, 0, 23, 0, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255}, expConntrackFlow: "tcp\t6 src=192.168.0.10 dst=192.168.77.73 sport=42625 dport=3333 packets=11 bytes=1914\t" + - "src=192.168.77.73 dst=192.168.0.10 sport=3333 dport=42625 packets=10 bytes=1858 mark=0x5 zone=100 " + - "start=2021-06-07 13:43:50.511990493 +0000 UTC stop=1970-01-01 00:00:00 +0000 UTC timeout=152(sec)", + "src=192.168.77.73 dst=192.168.0.10 sport=3333 dport=42625 packets=10 bytes=1858 mark=0x5 " + + "labels=0x22410c0c5b8691d37b5d0d2f5f220f4d/0xffffffffffffffffffffffffffffffff " + + "status=0x18e zone=100 use=0x1 start=2021-06-07 13:43:50.511990493 +0000 UTC stop=1970-01-01 00:00:00 +0000 UTC timeout=152(sec)", }, } diff --git a/nl/conntrack_linux.go b/nl/conntrack_linux.go index 6989d1ed..30b7786b 100644 --- a/nl/conntrack_linux.go +++ b/nl/conntrack_linux.go @@ -1,6 +1,10 @@ package nl -import "unsafe" +import ( + "unsafe" + + "golang.org/x/sys/unix" +) // Track the message sizes for the correct serialization/deserialization const ( @@ -11,11 +15,14 @@ const ( ) var L4ProtoMap = map[uint8]string{ - 6: "tcp", - 17: "udp", + unix.IPPROTO_ICMP: "icmp", + unix.IPPROTO_ICMPV6: "icmpv6", + unix.IPPROTO_TCP: "tcp", + unix.IPPROTO_UDP: "udp", } // From https://git.netfilter.org/libnetfilter_conntrack/tree/include/libnetfilter_conntrack/libnetfilter_conntrack_tcp.h +// // enum tcp_state { // TCP_CONNTRACK_NONE, // TCP_CONNTRACK_SYN_SENT, @@ -32,38 +39,38 @@ var L4ProtoMap = map[uint8]string{ // TCP_CONNTRACK_IGNORE // }; const ( - TCP_CONNTRACK_NONE = 0 - TCP_CONNTRACK_SYN_SENT = 1 - TCP_CONNTRACK_SYN_RECV = 2 - TCP_CONNTRACK_ESTABLISHED = 3 - TCP_CONNTRACK_FIN_WAIT = 4 - TCP_CONNTRACK_CLOSE_WAIT = 5 - TCP_CONNTRACK_LAST_ACK = 6 - TCP_CONNTRACK_TIME_WAIT = 7 - TCP_CONNTRACK_CLOSE = 8 - TCP_CONNTRACK_LISTEN = 9 - TCP_CONNTRACK_SYN_SENT2 = 9 - TCP_CONNTRACK_MAX = 10 - TCP_CONNTRACK_IGNORE = 11 + TCP_CONNTRACK_NONE = 0 + TCP_CONNTRACK_SYN_SENT = 1 + TCP_CONNTRACK_SYN_RECV = 2 + TCP_CONNTRACK_ESTABLISHED = 3 + TCP_CONNTRACK_FIN_WAIT = 4 + TCP_CONNTRACK_CLOSE_WAIT = 5 + TCP_CONNTRACK_LAST_ACK = 6 + TCP_CONNTRACK_TIME_WAIT = 7 + TCP_CONNTRACK_CLOSE = 8 + TCP_CONNTRACK_LISTEN = 9 + TCP_CONNTRACK_SYN_SENT2 = 9 + TCP_CONNTRACK_MAX = 10 + TCP_CONNTRACK_IGNORE = 11 ) // All the following constants are coming from: // https://github.com/torvalds/linux/blob/master/include/uapi/linux/netfilter/nfnetlink_conntrack.h -// enum cntl_msg_types { -// IPCTNL_MSG_CT_NEW, -// IPCTNL_MSG_CT_GET, -// IPCTNL_MSG_CT_DELETE, -// IPCTNL_MSG_CT_GET_CTRZERO, -// IPCTNL_MSG_CT_GET_STATS_CPU, -// IPCTNL_MSG_CT_GET_STATS, -// IPCTNL_MSG_CT_GET_DYING, -// IPCTNL_MSG_CT_GET_UNCONFIRMED, +// enum cntl_msg_types { +// IPCTNL_MSG_CT_NEW, +// IPCTNL_MSG_CT_GET, +// IPCTNL_MSG_CT_DELETE, +// IPCTNL_MSG_CT_GET_CTRZERO, +// IPCTNL_MSG_CT_GET_STATS_CPU, +// IPCTNL_MSG_CT_GET_STATS, +// IPCTNL_MSG_CT_GET_DYING, +// IPCTNL_MSG_CT_GET_UNCONFIRMED, // -// IPCTNL_MSG_MAX -// }; +// IPCTNL_MSG_MAX +// }; const ( - IPCTNL_MSG_CT_NEW = 0 + IPCTNL_MSG_CT_NEW = 0 IPCTNL_MSG_CT_GET = 1 IPCTNL_MSG_CT_DELETE = 2 ) @@ -80,36 +87,38 @@ const ( NLA_ALIGNTO uint16 = 4 // #define NLA_ALIGNTO 4 ) -// enum ctattr_type { -// CTA_UNSPEC, -// CTA_TUPLE_ORIG, -// CTA_TUPLE_REPLY, -// CTA_STATUS, -// CTA_PROTOINFO, -// CTA_HELP, -// CTA_NAT_SRC, +// enum ctattr_type { +// CTA_UNSPEC, +// CTA_TUPLE_ORIG, +// CTA_TUPLE_REPLY, +// CTA_STATUS, +// CTA_PROTOINFO, +// CTA_HELP, +// CTA_NAT_SRC, +// // #define CTA_NAT CTA_NAT_SRC /* backwards compatibility */ -// CTA_TIMEOUT, -// CTA_MARK, -// CTA_COUNTERS_ORIG, -// CTA_COUNTERS_REPLY, -// CTA_USE, -// CTA_ID, -// CTA_NAT_DST, -// CTA_TUPLE_MASTER, -// CTA_SEQ_ADJ_ORIG, -// CTA_NAT_SEQ_ADJ_ORIG = CTA_SEQ_ADJ_ORIG, -// CTA_SEQ_ADJ_REPLY, -// CTA_NAT_SEQ_ADJ_REPLY = CTA_SEQ_ADJ_REPLY, -// CTA_SECMARK, /* obsolete */ -// CTA_ZONE, -// CTA_SECCTX, -// CTA_TIMESTAMP, -// CTA_MARK_MASK, -// CTA_LABELS, -// CTA_LABELS_MASK, -// __CTA_MAX -// }; +// +// CTA_TIMEOUT, +// CTA_MARK, +// CTA_COUNTERS_ORIG, +// CTA_COUNTERS_REPLY, +// CTA_USE, +// CTA_ID, +// CTA_NAT_DST, +// CTA_TUPLE_MASTER, +// CTA_SEQ_ADJ_ORIG, +// CTA_NAT_SEQ_ADJ_ORIG = CTA_SEQ_ADJ_ORIG, +// CTA_SEQ_ADJ_REPLY, +// CTA_NAT_SEQ_ADJ_REPLY = CTA_SEQ_ADJ_REPLY, +// CTA_SECMARK, /* obsolete */ +// CTA_ZONE, +// CTA_SECCTX, +// CTA_TIMESTAMP, +// CTA_MARK_MASK, +// CTA_LABELS, +// CTA_LABELS_MASK, +// __CTA_MAX +// }; const ( CTA_TUPLE_ORIG = 1 CTA_TUPLE_REPLY = 2 @@ -127,27 +136,29 @@ const ( CTA_LABELS_MASK = 23 ) -// enum ctattr_tuple { -// CTA_TUPLE_UNSPEC, -// CTA_TUPLE_IP, -// CTA_TUPLE_PROTO, -// CTA_TUPLE_ZONE, -// __CTA_TUPLE_MAX -// }; +// enum ctattr_tuple { +// CTA_TUPLE_UNSPEC, +// CTA_TUPLE_IP, +// CTA_TUPLE_PROTO, +// CTA_TUPLE_ZONE, +// __CTA_TUPLE_MAX +// }; +// // #define CTA_TUPLE_MAX (__CTA_TUPLE_MAX - 1) const ( CTA_TUPLE_IP = 1 CTA_TUPLE_PROTO = 2 ) -// enum ctattr_ip { -// CTA_IP_UNSPEC, -// CTA_IP_V4_SRC, -// CTA_IP_V4_DST, -// CTA_IP_V6_SRC, -// CTA_IP_V6_DST, -// __CTA_IP_MAX -// }; +// enum ctattr_ip { +// CTA_IP_UNSPEC, +// CTA_IP_V4_SRC, +// CTA_IP_V4_DST, +// CTA_IP_V6_SRC, +// CTA_IP_V6_DST, +// __CTA_IP_MAX +// }; +// // #define CTA_IP_MAX (__CTA_IP_MAX - 1) const ( CTA_IP_V4_SRC = 1 @@ -156,50 +167,59 @@ const ( CTA_IP_V6_DST = 4 ) -// enum ctattr_l4proto { -// CTA_PROTO_UNSPEC, -// CTA_PROTO_NUM, -// CTA_PROTO_SRC_PORT, -// CTA_PROTO_DST_PORT, -// CTA_PROTO_ICMP_ID, -// CTA_PROTO_ICMP_TYPE, -// CTA_PROTO_ICMP_CODE, -// CTA_PROTO_ICMPV6_ID, -// CTA_PROTO_ICMPV6_TYPE, -// CTA_PROTO_ICMPV6_CODE, -// __CTA_PROTO_MAX -// }; +// enum ctattr_l4proto { +// CTA_PROTO_UNSPEC, +// CTA_PROTO_NUM, +// CTA_PROTO_SRC_PORT, +// CTA_PROTO_DST_PORT, +// CTA_PROTO_ICMP_ID, +// CTA_PROTO_ICMP_TYPE, +// CTA_PROTO_ICMP_CODE, +// CTA_PROTO_ICMPV6_ID, +// CTA_PROTO_ICMPV6_TYPE, +// CTA_PROTO_ICMPV6_CODE, +// __CTA_PROTO_MAX +// }; +// // #define CTA_PROTO_MAX (__CTA_PROTO_MAX - 1) const ( - CTA_PROTO_NUM = 1 - CTA_PROTO_SRC_PORT = 2 - CTA_PROTO_DST_PORT = 3 + CTA_PROTO_NUM = 1 + CTA_PROTO_SRC_PORT = 2 + CTA_PROTO_DST_PORT = 3 + CTA_PROTO_ICMP_ID = 4 + CTA_PROTO_ICMP_TYPE = 5 + CTA_PROTO_ICMP_CODE = 6 + CTA_PROTO_ICMPV6_ID = 7 + CTA_PROTO_ICMPV6_TYPE = 8 + CTA_PROTO_ICMPV6_CODE = 9 ) -// enum ctattr_protoinfo { -// CTA_PROTOINFO_UNSPEC, -// CTA_PROTOINFO_TCP, -// CTA_PROTOINFO_DCCP, -// CTA_PROTOINFO_SCTP, -// __CTA_PROTOINFO_MAX -// }; +// enum ctattr_protoinfo { +// CTA_PROTOINFO_UNSPEC, +// CTA_PROTOINFO_TCP, +// CTA_PROTOINFO_DCCP, +// CTA_PROTOINFO_SCTP, +// __CTA_PROTOINFO_MAX +// }; +// // #define CTA_PROTOINFO_MAX (__CTA_PROTOINFO_MAX - 1) const ( CTA_PROTOINFO_UNSPEC = 0 - CTA_PROTOINFO_TCP = 1 - CTA_PROTOINFO_DCCP = 2 - CTA_PROTOINFO_SCTP = 3 + CTA_PROTOINFO_TCP = 1 + CTA_PROTOINFO_DCCP = 2 + CTA_PROTOINFO_SCTP = 3 ) -// enum ctattr_protoinfo_tcp { -// CTA_PROTOINFO_TCP_UNSPEC, -// CTA_PROTOINFO_TCP_STATE, -// CTA_PROTOINFO_TCP_WSCALE_ORIGINAL, -// CTA_PROTOINFO_TCP_WSCALE_REPLY, -// CTA_PROTOINFO_TCP_FLAGS_ORIGINAL, -// CTA_PROTOINFO_TCP_FLAGS_REPLY, -// __CTA_PROTOINFO_TCP_MAX -// }; +// enum ctattr_protoinfo_tcp { +// CTA_PROTOINFO_TCP_UNSPEC, +// CTA_PROTOINFO_TCP_STATE, +// CTA_PROTOINFO_TCP_WSCALE_ORIGINAL, +// CTA_PROTOINFO_TCP_WSCALE_REPLY, +// CTA_PROTOINFO_TCP_FLAGS_ORIGINAL, +// CTA_PROTOINFO_TCP_FLAGS_REPLY, +// __CTA_PROTOINFO_TCP_MAX +// }; +// // #define CTA_PROTOINFO_TCP_MAX (__CTA_PROTOINFO_TCP_MAX - 1) const ( CTA_PROTOINFO_TCP_STATE = 1 @@ -209,15 +229,16 @@ const ( CTA_PROTOINFO_TCP_FLAGS_REPLY = 5 ) -// enum ctattr_counters { -// CTA_COUNTERS_UNSPEC, -// CTA_COUNTERS_PACKETS, /* 64bit counters */ -// CTA_COUNTERS_BYTES, /* 64bit counters */ -// CTA_COUNTERS32_PACKETS, /* old 32bit counters, unused */ -// CTA_COUNTERS32_BYTES, /* old 32bit counters, unused */ -// CTA_COUNTERS_PAD, -// __CTA_COUNTERS_M -// }; +// enum ctattr_counters { +// CTA_COUNTERS_UNSPEC, +// CTA_COUNTERS_PACKETS, /* 64bit counters */ +// CTA_COUNTERS_BYTES, /* 64bit counters */ +// CTA_COUNTERS32_PACKETS, /* old 32bit counters, unused */ +// CTA_COUNTERS32_BYTES, /* old 32bit counters, unused */ +// CTA_COUNTERS_PAD, +// __CTA_COUNTERS_M +// }; +// // #define CTA_COUNTERS_MAX (__CTA_COUNTERS_MAX - 1) const ( CTA_COUNTERS_PACKETS = 1 @@ -233,12 +254,14 @@ const ( ) // /* General form of address family dependent message. -// */ -// struct nfgenmsg { -// __u8 nfgen_family; /* AF_xxx */ -// __u8 version; /* nfnetlink version */ -// __be16 res_id; /* resource id */ -// }; +// +// */ +// +// struct nfgenmsg { +// __u8 nfgen_family; /* AF_xxx */ +// __u8 version; /* nfnetlink version */ +// __be16 res_id; /* resource id */ +// }; type Nfgenmsg struct { NfgenFamily uint8 Version uint8 diff --git a/rule.go b/rule.go index 9d74c7cd..8086d3f4 100644 --- a/rule.go +++ b/rule.go @@ -46,6 +46,12 @@ func (r Rule) String() string { r.Priority, from, to, r.Table, r.typeString()) } +// RuleUpdate is sent when a route changes - type is RTM_NEWRULE or RTM_DELRULE +type RuleUpdate struct { + Type uint16 + Rule +} + // NewRule return empty rules. func NewRule() *Rule { return &Rule{ diff --git a/rule_linux.go b/rule_linux.go index dba99147..65c1b59a 100644 --- a/rule_linux.go +++ b/rule_linux.go @@ -5,8 +5,10 @@ import ( "errors" "fmt" "net" + "syscall" "github.com/vishvananda/netlink/nl" + "github.com/vishvananda/netns" "golang.org/x/sys/unix" ) @@ -227,73 +229,11 @@ func (h *Handle) RuleListFiltered(family int, filter *Rule, filterMask uint64) ( var res = make([]Rule, 0) for i := range msgs { - msg := nl.DeserializeRtMsg(msgs[i]) - attrs, err := nl.ParseRouteAttr(msgs[i][msg.Len():]) + rule, err := deserializeRule(msgs[i]) if err != nil { return nil, err } - rule := NewRule() - rule.Priority = 0 // The default priority from kernel - - rule.Invert = msg.Flags&FibRuleInvert > 0 - rule.Family = int(msg.Family) - rule.Tos = uint(msg.Tos) - - for j := range attrs { - switch attrs[j].Attr.Type { - case unix.RTA_TABLE: - rule.Table = int(native.Uint32(attrs[j].Value[0:4])) - case nl.FRA_SRC: - rule.Src = &net.IPNet{ - IP: attrs[j].Value, - Mask: net.CIDRMask(int(msg.Src_len), 8*len(attrs[j].Value)), - } - case nl.FRA_DST: - rule.Dst = &net.IPNet{ - IP: attrs[j].Value, - Mask: net.CIDRMask(int(msg.Dst_len), 8*len(attrs[j].Value)), - } - case nl.FRA_FWMARK: - rule.Mark = native.Uint32(attrs[j].Value[0:4]) - case nl.FRA_FWMASK: - mask := native.Uint32(attrs[j].Value[0:4]) - rule.Mask = &mask - case nl.FRA_TUN_ID: - rule.TunID = uint(native.Uint64(attrs[j].Value[0:8])) - case nl.FRA_IIFNAME: - rule.IifName = string(attrs[j].Value[:len(attrs[j].Value)-1]) - case nl.FRA_OIFNAME: - rule.OifName = string(attrs[j].Value[:len(attrs[j].Value)-1]) - case nl.FRA_SUPPRESS_PREFIXLEN: - i := native.Uint32(attrs[j].Value[0:4]) - if i != 0xffffffff { - rule.SuppressPrefixlen = int(i) - } - case nl.FRA_SUPPRESS_IFGROUP: - i := native.Uint32(attrs[j].Value[0:4]) - if i != 0xffffffff { - rule.SuppressIfgroup = int(i) - } - case nl.FRA_FLOW: - rule.Flow = int(native.Uint32(attrs[j].Value[0:4])) - case nl.FRA_GOTO: - rule.Goto = int(native.Uint32(attrs[j].Value[0:4])) - case nl.FRA_PRIORITY: - rule.Priority = int(native.Uint32(attrs[j].Value[0:4])) - case nl.FRA_IP_PROTO: - rule.IPProto = int(native.Uint32(attrs[j].Value[0:4])) - case nl.FRA_DPORT_RANGE: - rule.Dport = NewRulePortRange(native.Uint16(attrs[j].Value[0:2]), native.Uint16(attrs[j].Value[2:4])) - case nl.FRA_SPORT_RANGE: - rule.Sport = NewRulePortRange(native.Uint16(attrs[j].Value[0:2]), native.Uint16(attrs[j].Value[2:4])) - case nl.FRA_UID_RANGE: - rule.UIDRange = NewRuleUIDRange(native.Uint32(attrs[j].Value[0:4]), native.Uint32(attrs[j].Value[4:8])) - case nl.FRA_PROTOCOL: - rule.Protocol = uint8(attrs[j].Value[0]) - } - } - if filter != nil { switch { case filterMask&RT_FILTER_SRC != 0 && @@ -376,3 +316,189 @@ func (r Rule) typeString() string { return fmt.Sprintf("type(0x%x)", r.Type) } } + +// deserializeRule decodes a binary netlink message into a Rule struct +func deserializeRule(m []byte) (*Rule, error) { + msg := nl.DeserializeRtMsg(m) + attrs, err := nl.ParseRouteAttr(m[msg.Len():]) + if err != nil { + return nil, err + } + + rule := NewRule() + rule.Priority = 0 // The default priority from kernel + + rule.Invert = msg.Flags&FibRuleInvert > 0 + rule.Family = int(msg.Family) + rule.Tos = uint(msg.Tos) + + for j := range attrs { + switch attrs[j].Attr.Type { + case unix.RTA_TABLE: + rule.Table = int(native.Uint32(attrs[j].Value[0:4])) + case nl.FRA_SRC: + rule.Src = &net.IPNet{ + IP: attrs[j].Value, + Mask: net.CIDRMask(int(msg.Src_len), 8*len(attrs[j].Value)), + } + case nl.FRA_DST: + rule.Dst = &net.IPNet{ + IP: attrs[j].Value, + Mask: net.CIDRMask(int(msg.Dst_len), 8*len(attrs[j].Value)), + } + case nl.FRA_FWMARK: + rule.Mark = native.Uint32(attrs[j].Value[0:4]) + case nl.FRA_FWMASK: + mask := native.Uint32(attrs[j].Value[0:4]) + rule.Mask = &mask + case nl.FRA_TUN_ID: + rule.TunID = uint(native.Uint64(attrs[j].Value[0:8])) + case nl.FRA_IIFNAME: + rule.IifName = string(attrs[j].Value[:len(attrs[j].Value)-1]) + case nl.FRA_OIFNAME: + rule.OifName = string(attrs[j].Value[:len(attrs[j].Value)-1]) + case nl.FRA_SUPPRESS_PREFIXLEN: + i := native.Uint32(attrs[j].Value[0:4]) + if i != 0xffffffff { + rule.SuppressPrefixlen = int(i) + } + case nl.FRA_SUPPRESS_IFGROUP: + i := native.Uint32(attrs[j].Value[0:4]) + if i != 0xffffffff { + rule.SuppressIfgroup = int(i) + } + case nl.FRA_FLOW: + rule.Flow = int(native.Uint32(attrs[j].Value[0:4])) + case nl.FRA_GOTO: + rule.Goto = int(native.Uint32(attrs[j].Value[0:4])) + case nl.FRA_PRIORITY: + rule.Priority = int(native.Uint32(attrs[j].Value[0:4])) + case nl.FRA_IP_PROTO: + rule.IPProto = int(native.Uint32(attrs[j].Value[0:4])) + case nl.FRA_DPORT_RANGE: + rule.Dport = NewRulePortRange(native.Uint16(attrs[j].Value[0:2]), native.Uint16(attrs[j].Value[2:4])) + case nl.FRA_SPORT_RANGE: + rule.Sport = NewRulePortRange(native.Uint16(attrs[j].Value[0:2]), native.Uint16(attrs[j].Value[2:4])) + case nl.FRA_UID_RANGE: + rule.UIDRange = NewRuleUIDRange(native.Uint32(attrs[j].Value[0:4]), native.Uint32(attrs[j].Value[4:8])) + case nl.FRA_PROTOCOL: + rule.Protocol = uint8(attrs[j].Value[0]) + } + } + + return rule, nil +} + +// RuleSubscribe takes a chan down which notifications will be sent +// when rules are added or deleted. Close the 'done' chan to stop subscription. +func RuleSubscribe(ch chan<- RuleUpdate, done <-chan struct{}) error { + return ruleSubscribeAt(netns.None(), netns.None(), ch, done, nil, false, 0, nil, false) +} + +// RuleSubscribeAt works like RuleSubscribe plus it allows the caller +// to choose the network namespace in which to subscribe (ns). +func RuleSubscribeAt(ns netns.NsHandle, ch chan<- RuleUpdate, done <-chan struct{}) error { + return ruleSubscribeAt(ns, netns.None(), ch, done, nil, false, 0, nil, false) +} + +// RuleSubscribeOptions contains a set of options to use with +// RuleSubscribeWithOptions. +type RuleSubscribeOptions struct { + Namespace *netns.NsHandle + ErrorCallback func(error) + ListExisting bool + ReceiveBufferSize int + ReceiveBufferForceSize bool + ReceiveTimeout *unix.Timeval +} + +// RuleSubscribeWithOptions work like RuleSubscribe but enable to +// provide additional options to modify the behavior. Currently, the +// namespace can be provided as well as an error callback. +func RuleSubscribeWithOptions(ch chan<- RuleUpdate, done <-chan struct{}, options RuleSubscribeOptions) error { + if options.Namespace == nil { + none := netns.None() + options.Namespace = &none + } + return ruleSubscribeAt(*options.Namespace, netns.None(), ch, done, options.ErrorCallback, options.ListExisting, + options.ReceiveBufferSize, options.ReceiveTimeout, options.ReceiveBufferForceSize) +} + +func ruleSubscribeAt(newNs, curNs netns.NsHandle, ch chan<- RuleUpdate, done <-chan struct{}, cberr func(error), listExisting bool, + rcvbuf int, rcvTimeout *unix.Timeval, rcvbufForce bool) error { + s, err := nl.SubscribeAt(newNs, curNs, unix.NETLINK_ROUTE, unix.RTNLGRP_IPV4_RULE, unix.RTNLGRP_IPV6_RULE) + if err != nil { + return err + } + if rcvTimeout != nil { + if err := s.SetReceiveTimeout(rcvTimeout); err != nil { + return err + } + } + if rcvbuf != 0 { + err = s.SetReceiveBufferSize(rcvbuf, rcvbufForce) + if err != nil { + return err + } + } + if done != nil { + go func() { + <-done + s.Close() + }() + } + if listExisting { + req := pkgHandle.newNetlinkRequest(unix.RTM_GETRULE, + unix.NLM_F_DUMP) + infmsg := nl.NewIfInfomsg(unix.AF_UNSPEC) + req.AddData(infmsg) + if err := s.Send(req); err != nil { + return err + } + } + go func() { + defer close(ch) + for { + msgs, from, err := s.Receive() + if err != nil { + if cberr != nil { + cberr(fmt.Errorf("Receive failed: %v", + err)) + } + return + } + if from.Pid != nl.PidKernel { + if cberr != nil { + cberr(fmt.Errorf("Wrong sender portid %d, expected %d", from.Pid, nl.PidKernel)) + } + continue + } + for _, m := range msgs { + if m.Header.Type == unix.NLMSG_DONE { + continue + } + if m.Header.Type == unix.NLMSG_ERROR { + error := int32(native.Uint32(m.Data[0:4])) + if error == 0 { + continue + } + if cberr != nil { + cberr(fmt.Errorf("error message: %v", + syscall.Errno(-error))) + } + continue + } + rule, err := deserializeRule(m.Data) + if err != nil { + if cberr != nil { + cberr(err) + } + continue + } + ch <- RuleUpdate{Type: m.Header.Type, Rule: *rule} + } + } + }() + + return nil +} diff --git a/rule_test.go b/rule_test.go index 67338bce..fe011f92 100644 --- a/rule_test.go +++ b/rule_test.go @@ -4,10 +4,14 @@ package netlink import ( + "encoding/json" + "fmt" + "math/rand" "net" "testing" "time" + "github.com/vishvananda/netns" "golang.org/x/sys/unix" ) @@ -695,3 +699,253 @@ func ruleEquals(a, b Rule) bool { (ptrEqual(a.Mask, b.Mask) || (a.Mark != 0 && (a.Mask == nil && *b.Mask == 0xFFFFFFFF || b.Mask == nil && *a.Mask == 0xFFFFFFFF))) } + +// expectRuleUpdate returns whether the expected updated is received within one minute. +func expectRuleUpdate(ch <-chan RuleUpdate, t uint16, match func(Rule) bool) ([]Rule, bool) { + for { + receivedRules := []Rule{} + timeout := time.After(time.Minute) + select { + case update := <-ch: + j, _ := json.Marshal(update) + fmt.Printf("update: %s\n", string(j)) + if update.Type == t && match(update.Rule) { + receivedRules = append(receivedRules, update.Rule) + return receivedRules, true + } + case <-timeout: + return receivedRules, false + } + } +} + +func TestRuleSubscribe(t *testing.T) { + tearDown := setUpNetlinkTest(t) + defer tearDown() + + ch := make(chan RuleUpdate) + done := make(chan struct{}) + defer close(done) + var lastError error + defer func() { + if lastError != nil { + t.Fatalf("Fatal error received during subscription: %v", lastError) + } + }() + if err := RuleSubscribe(ch, done); err != nil { + lastError = err + } + + srcNet := &net.IPNet{IP: net.IPv4(172, 16, 0, 1), Mask: net.CIDRMask(16, 32)} + dstNet := &net.IPNet{IP: net.IPv4(172, 16, 1, 1), Mask: net.CIDRMask(24, 32)} + + rule := NewRule() + rule.Family = FAMILY_V4 + rule.Table = unix.RT_TABLE_MAIN + rule.Src = srcNet + rule.Dst = dstNet + rule.Priority = 5 + rule.OifName = "lo" + rule.IifName = "lo" + rule.Invert = true + rule.Tos = 0x10 + rule.Dport = NewRulePortRange(80, 80) + rule.Sport = NewRulePortRange(1000, 1024) + rule.IPProto = unix.IPPROTO_UDP + rule.UIDRange = NewRuleUIDRange(100, 100) + rule.Protocol = unix.RTPROT_KERNEL + + match := func(r Rule) bool { + return r.Src.String() == srcNet.String() && + r.Dst.String() == dstNet.String() && + r.Priority == 5 && + r.OifName == "lo" && + r.IifName == "lo" && + r.Invert == true && + r.Tos == 0x10 && + r.Dport.Start == 80 && + r.Dport.End == 80 && + r.Sport.Start == 1000 && + r.Sport.End == 1024 && + r.IPProto == unix.IPPROTO_UDP && + r.UIDRange.Start == 100 && + r.UIDRange.End == 100 && + r.Protocol == unix.RTPROT_KERNEL + } + + if err := RuleAdd(rule); err != nil { + t.Fatal(err) + } + + receivedRules, ok := expectRuleUpdate(ch, unix.RTM_NEWRULE, match) + if !ok { + t.Fatal("Add update not received as expected", receivedRules) + } + + if err := RuleDel(rule); err != nil { + t.Fatal(err) + } + receivedRules, ok = expectRuleUpdate(ch, unix.RTM_DELRULE, match) + if !ok { + t.Fatal("Del update not received as expected", receivedRules) + } +} + +func TestRuleSubscribeWithOptions(t *testing.T) { + tearDown := setUpNetlinkTest(t) + defer tearDown() + + ch := make(chan RuleUpdate, 10) + done := make(chan struct{}) + defer close(done) + var lastError error + defer func() { + if lastError != nil { + t.Fatalf("Fatal error received during subscription: %v", lastError) + } + }() + if err := RuleSubscribeWithOptions(ch, done, RuleSubscribeOptions{ + ErrorCallback: func(err error) { lastError = err }, + }); err != nil { + t.Fatal(err) + } + + rule := NewRule() + rule.Table = rand.Intn(1000) + 10000 + rule.Priority = rand.Intn(1000) + 10000 + rule.IifName = "lo" + + match := func(r Rule) bool { + return r.Table == rule.Table && + r.Priority == rule.Priority && + r.IifName == "lo" + } + + if err := RuleAdd(rule); err != nil { + t.Fatal(err) + } + receivedRules, ok := expectRuleUpdate(ch, unix.RTM_NEWRULE, match) + if !ok { + t.Fatal("Add update not received as expected", receivedRules) + } + + if err := RuleDel(rule); err != nil { + t.Fatal(err) + } + receivedRules, ok = expectRuleUpdate(ch, unix.RTM_DELRULE, match) + if !ok { + t.Fatal("Del update not received as expected", receivedRules) + } +} + +func TestRuleSubscribeAt(t *testing.T) { + skipUnlessRoot(t) + + // Create an handle on a custom netns + newNs, err := netns.New() + if err != nil { + t.Fatal(err) + } + defer newNs.Close() + + nh, err := NewHandleAt(newNs) + if err != nil { + t.Fatal(err) + } + defer nh.Close() + + ch := make(chan RuleUpdate) + done := make(chan struct{}) + defer close(done) + var lastError error + defer func() { + if lastError != nil { + t.Fatalf("Fatal error received during subscription: %v", lastError) + } + }() + if err := RuleSubscribeAt(newNs, ch, done); err != nil { + lastError = err + t.Fatal(err) + } + + rule := NewRule() + rule.Table = rand.Intn(1000) + 10000 + rule.Priority = rand.Intn(1000) + 10000 + rule.IifName = "lo" + + match := func(r Rule) bool { + return r.Table == rule.Table && + r.Priority == rule.Priority && + r.IifName == "lo" + } + + if err := nh.RuleAdd(rule); err != nil { + t.Fatal(err) + } + receivedRules, ok := expectRuleUpdate(ch, unix.RTM_NEWRULE, match) + if !ok { + t.Fatal("Add update not received as expected", receivedRules) + } + + if err := nh.RuleDel(rule); err != nil { + t.Fatal(err) + } + receivedRules, ok = expectRuleUpdate(ch, unix.RTM_DELRULE, match) + if !ok { + t.Fatal("Del update not received as expected", receivedRules) + } +} + +func TestRuleSubscribeListExisting(t *testing.T) { + skipUnlessRoot(t) + + // Create an handle on a custom netns + newNs, err := netns.New() + if err != nil { + t.Fatal(err) + } + defer newNs.Close() + + nh, err := NewHandleAt(newNs) + if err != nil { + t.Fatal(err) + } + defer nh.Close() + + rule := NewRule() + rule.Table = rand.Intn(1000) + 10000 + rule.Priority = rand.Intn(1000) + 10000 + rule.IifName = "lo" + + if err := nh.RuleAdd(rule); err != nil { + t.Fatal(err) + } + + ch := make(chan RuleUpdate) + done := make(chan struct{}) + defer close(done) + if err := RuleSubscribeWithOptions(ch, done, RuleSubscribeOptions{ + Namespace: &newNs, + ListExisting: true, + }); err != nil { + t.Fatal(err) + } + + match := func(r Rule) bool { + return r.Table == rule.Table && + r.Priority == rule.Priority && + r.IifName == "lo" + } + + receivedRules, ok := expectRuleUpdate(ch, unix.RTM_NEWRULE, match) + if !ok { + t.Fatal("Add update not received as expected", receivedRules) + } + if err := nh.RuleDel(rule); err != nil { + t.Fatal(err) + } + receivedRules, ok = expectRuleUpdate(ch, unix.RTM_DELRULE, match) + if !ok { + t.Fatal("Del update not received as expected", receivedRules) + } +}