diff --git a/include/net/proto_memory.h b/include/net/proto_memory.h
index 8e91a8fa31b5..8e8432b13515 100644
--- a/include/net/proto_memory.h
+++ b/include/net/proto_memory.h
@@ -31,13 +31,22 @@ static inline bool sk_under_memory_pressure(const struct sock *sk)
 	if (!sk->sk_prot->memory_pressure)
 		return false;
 
-	if (mem_cgroup_sk_enabled(sk) &&
-	    mem_cgroup_sk_under_memory_pressure(sk))
-		return true;
+	if (mem_cgroup_sk_enabled(sk)) {
+		if (mem_cgroup_sk_under_memory_pressure(sk))
+			return true;
+
+		if (mem_cgroup_sk_isolated(sk))
+			return false;
+	}
 
 	return !!READ_ONCE(*sk->sk_prot->memory_pressure);
 }
 
+static inline bool sk_should_enter_memory_pressure(struct sock *sk)
+{
+	return !mem_cgroup_sk_enabled(sk) || !mem_cgroup_sk_isolated(sk);
+}
+
 static inline long
 proto_memory_allocated(const struct proto *prot)
 {
diff --git a/include/net/sock.h b/include/net/sock.h
index 63a6a48afb48..703cb9116c6e 100644
--- a/include/net/sock.h
+++ b/include/net/sock.h
@@ -2596,10 +2596,41 @@ static inline gfp_t gfp_memcg_charge(void)
 	return in_softirq() ? GFP_ATOMIC : GFP_KERNEL;
 }
 
+#define SK_BPF_MEMCG_FLAG_MASK	(SK_BPF_MEMCG_FLAG_MAX - 1)
+#define SK_BPF_MEMCG_PTR_MASK	~SK_BPF_MEMCG_FLAG_MASK
+
 #ifdef CONFIG_MEMCG
 static inline struct mem_cgroup *mem_cgroup_from_sk(const struct sock *sk)
 {
+#ifdef CONFIG_CGROUP_BPF
+	unsigned long val = (unsigned long)sk->sk_memcg;
+
+	val &= SK_BPF_MEMCG_PTR_MASK;
+	return (struct mem_cgroup *)val;
+#else
 	return sk->sk_memcg;
+#endif
+}
+
+static inline void mem_cgroup_sk_set_flags(struct sock *sk, unsigned short flags)
+{
+#ifdef CONFIG_CGROUP_BPF
+	unsigned long val = (unsigned long)mem_cgroup_from_sk(sk);
+
+	val |= flags;
+	sk->sk_memcg = (struct mem_cgroup *)val;
+#endif
+}
+
+static inline unsigned short mem_cgroup_sk_get_flags(const struct sock *sk)
+{
+#ifdef CONFIG_CGROUP_BPF
+	unsigned long val = (unsigned long)sk->sk_memcg;
+
+	return val & SK_BPF_MEMCG_FLAG_MASK;
+#else
+	return 0;
+#endif
 }
 
 static inline bool mem_cgroup_sk_enabled(const struct sock *sk)
@@ -2607,6 +2638,11 @@ static inline bool mem_cgroup_sk_enabled(const struct sock *sk)
 	return mem_cgroup_sockets_enabled && mem_cgroup_from_sk(sk);
 }
 
+static inline bool mem_cgroup_sk_isolated(const struct sock *sk)
+{
+	return mem_cgroup_sk_get_flags(sk) & SK_BPF_MEMCG_SOCK_ISOLATED;
+}
+
 static inline bool mem_cgroup_sk_under_memory_pressure(const struct sock *sk)
 {
 	struct mem_cgroup *memcg = mem_cgroup_from_sk(sk);
@@ -2629,11 +2665,25 @@ static inline struct mem_cgroup *mem_cgroup_from_sk(const struct sock *sk)
 	return NULL;
 }
 
+static inline void mem_cgroup_sk_set_flags(struct sock *sk, unsigned short flags)
+{
+}
+
+static inline unsigned short mem_cgroup_sk_get_flags(const struct sock *sk)
+{
+	return 0;
+}
+
 static inline bool mem_cgroup_sk_enabled(const struct sock *sk)
 {
 	return false;
 }
 
+static inline bool mem_cgroup_sk_isolated(const struct sock *sk)
+{
+	return false;
+}
+
 static inline bool mem_cgroup_sk_under_memory_pressure(const struct sock *sk)
 {
 	return false;
diff --git a/include/net/tcp.h b/include/net/tcp.h
index 2936b8175950..0191a4585bba 100644
--- a/include/net/tcp.h
+++ b/include/net/tcp.h
@@ -275,9 +275,13 @@ extern unsigned long tcp_memory_pressure;
 /* optimized version of sk_under_memory_pressure() for TCP sockets */
 static inline bool tcp_under_memory_pressure(const struct sock *sk)
 {
-	if (mem_cgroup_sk_enabled(sk) &&
-	    mem_cgroup_sk_under_memory_pressure(sk))
-		return true;
+	if (mem_cgroup_sk_enabled(sk)) {
+		if (mem_cgroup_sk_under_memory_pressure(sk))
+			return true;
+
+		if (mem_cgroup_sk_isolated(sk))
+			return false;
+	}
 
 	return READ_ONCE(tcp_memory_pressure);
 }
diff --git a/include/uapi/linux/bpf.h b/include/uapi/linux/bpf.h
index 233de8677382..52b8c2278589 100644
--- a/include/uapi/linux/bpf.h
+++ b/include/uapi/linux/bpf.h
@@ -7182,6 +7182,7 @@ enum {
 	TCP_BPF_SYN_MAC         = 1007, /* Copy the MAC, IP[46], and TCP header */
 	TCP_BPF_SOCK_OPS_CB_FLAGS = 1008, /* Get or Set TCP sock ops flags */
 	SK_BPF_CB_FLAGS		= 1009, /* Get or set sock ops flags in socket */
+	SK_BPF_MEMCG_FLAGS	= 1010, /* Get or Set flags saved in sk->sk_memcg */
 };
 
 enum {
@@ -7204,6 +7205,11 @@ enum {
 	 */
 };
 
+enum {
+	SK_BPF_MEMCG_SOCK_ISOLATED	= (1UL << 0),
+	SK_BPF_MEMCG_FLAG_MAX		= (1UL << 1),
+};
+
 struct bpf_perf_event_value {
 	__u64 counter;
 	__u64 enabled;
diff --git a/net/core/filter.c b/net/core/filter.c
index 63f3baee2daf..eb2f87a732ef 100644
--- a/net/core/filter.c
+++ b/net/core/filter.c
@@ -5267,6 +5267,35 @@ static int sk_bpf_set_get_cb_flags(struct sock *sk, char *optval, bool getopt)
 	return 0;
 }
 
+static int sk_bpf_get_memcg_flags(struct sock *sk, char *optval)
+{
+	if (!sk_has_account(sk))
+		return -EOPNOTSUPP;
+
+	*(u32 *)optval = mem_cgroup_sk_get_flags(sk);
+
+	return 0;
+}
+
+static int sk_bpf_set_memcg_flags(struct sock *sk, char *optval, int optlen)
+{
+	u32 flags;
+
+	if (optlen != sizeof(u32))
+		return -EINVAL;
+
+	if (!sk_has_account(sk))
+		return -EOPNOTSUPP;
+
+	flags = *(u32 *)optval;
+	if (flags >= SK_BPF_MEMCG_FLAG_MAX)
+		return -EINVAL;
+
+	mem_cgroup_sk_set_flags(sk, flags);
+
+	return 0;
+}
+
 static int sol_socket_sockopt(struct sock *sk, int optname,
 			      char *optval, int *optlen,
 			      bool getopt)
@@ -5284,6 +5313,7 @@ static int sol_socket_sockopt(struct sock *sk, int optname,
 	case SO_BINDTOIFINDEX:
 	case SO_TXREHASH:
 	case SK_BPF_CB_FLAGS:
+	case SK_BPF_MEMCG_FLAGS:
 		if (*optlen != sizeof(int))
 			return -EINVAL;
 		break;
@@ -5293,8 +5323,15 @@ static int sol_socket_sockopt(struct sock *sk, int optname,
 		return -EINVAL;
 	}
 
-	if (optname == SK_BPF_CB_FLAGS)
+	switch (optname) {
+	case SK_BPF_CB_FLAGS:
 		return sk_bpf_set_get_cb_flags(sk, optval, getopt);
+	case SK_BPF_MEMCG_FLAGS:
+		if (!IS_ENABLED(CONFIG_MEMCG) || !getopt)
+			return -EOPNOTSUPP;
+
+		return sk_bpf_get_memcg_flags(sk, optval);
+	}
 
 	if (getopt) {
 		if (optname == SO_BINDTODEVICE)
@@ -5723,6 +5760,44 @@ static const struct bpf_func_proto bpf_sock_addr_getsockopt_proto = {
 	.arg5_type	= ARG_CONST_SIZE,
 };
 
+BPF_CALL_5(bpf_unlocked_sock_setsockopt, struct sock *, sk, int, level,
+	   int, optname, char *, optval, int, optlen)
+{
+	if (IS_ENABLED(CONFIG_MEMCG) &&
+	    level == SOL_SOCKET && optname == SK_BPF_MEMCG_FLAGS)
+		return sk_bpf_set_memcg_flags(sk, optval, optlen);
+
+	return __bpf_setsockopt(sk, level, optname, optval, optlen);
+}
+
+static const struct bpf_func_proto bpf_unlocked_sock_setsockopt_proto = {
+	.func		= bpf_unlocked_sock_setsockopt,
+	.gpl_only	= false,
+	.ret_type	= RET_INTEGER,
+	.arg1_type	= ARG_PTR_TO_CTX,
+	.arg2_type	= ARG_ANYTHING,
+	.arg3_type	= ARG_ANYTHING,
+	.arg4_type	= ARG_PTR_TO_MEM | MEM_RDONLY,
+	.arg5_type	= ARG_CONST_SIZE,
+};
+
+BPF_CALL_5(bpf_unlocked_sock_getsockopt, struct sock *, sk, int, level,
+	   int, optname, char *, optval, int, optlen)
+{
+	return __bpf_getsockopt(sk, level, optname, optval, optlen);
+}
+
+static const struct bpf_func_proto bpf_unlocked_sock_getsockopt_proto = {
+	.func		= bpf_unlocked_sock_getsockopt,
+	.gpl_only	= false,
+	.ret_type	= RET_INTEGER,
+	.arg1_type	= ARG_PTR_TO_CTX,
+	.arg2_type	= ARG_ANYTHING,
+	.arg3_type	= ARG_ANYTHING,
+	.arg4_type	= ARG_PTR_TO_UNINIT_MEM,
+	.arg5_type	= ARG_CONST_SIZE,
+};
+
 BPF_CALL_5(bpf_sock_ops_setsockopt, struct bpf_sock_ops_kern *, bpf_sock,
 	   int, level, int, optname, char *, optval, int, optlen)
 {
@@ -8051,6 +8126,20 @@ sock_filter_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog)
 		return &bpf_sk_storage_get_cg_sock_proto;
 	case BPF_FUNC_ktime_get_coarse_ns:
 		return &bpf_ktime_get_coarse_ns_proto;
+	case BPF_FUNC_setsockopt:
+		switch (prog->expected_attach_type) {
+		case BPF_CGROUP_INET_SOCK_CREATE:
+			return &bpf_unlocked_sock_setsockopt_proto;
+		default:
+			return NULL;
+		}
+	case BPF_FUNC_getsockopt:
+		switch (prog->expected_attach_type) {
+		case BPF_CGROUP_INET_SOCK_CREATE:
+			return &bpf_unlocked_sock_getsockopt_proto;
+		default:
+			return NULL;
+		}
 	default:
 		return bpf_base_func_proto(func_id, prog);
 	}
diff --git a/net/core/sock.c b/net/core/sock.c
index 8002ac6293dc..bda1655fa9bf 100644
--- a/net/core/sock.c
+++ b/net/core/sock.c
@@ -1046,17 +1046,21 @@ static int sock_reserve_memory(struct sock *sk, int bytes)
 	if (!charged)
 		return -ENOMEM;
 
-	/* pre-charge to forward_alloc */
-	sk_memory_allocated_add(sk, pages);
-	allocated = sk_memory_allocated(sk);
-	/* If the system goes into memory pressure with this
-	 * precharge, give up and return error.
-	 */
-	if (allocated > sk_prot_mem_limits(sk, 1)) {
-		sk_memory_allocated_sub(sk, pages);
-		mem_cgroup_sk_uncharge(sk, pages);
-		return -ENOMEM;
+	if (!mem_cgroup_sk_isolated(sk)) {
+		/* pre-charge to forward_alloc */
+		sk_memory_allocated_add(sk, pages);
+		allocated = sk_memory_allocated(sk);
+
+		/* If the system goes into memory pressure with this
+		 * precharge, give up and return error.
+		 */
+		if (allocated > sk_prot_mem_limits(sk, 1)) {
+			sk_memory_allocated_sub(sk, pages);
+			mem_cgroup_sk_uncharge(sk, pages);
+			return -ENOMEM;
+		}
 	}
+
 	sk_forward_alloc_add(sk, pages << PAGE_SHIFT);
 
 	WRITE_ONCE(sk->sk_reserved_mem,
@@ -2515,6 +2519,7 @@ struct sock *sk_clone_lock(const struct sock *sk, const gfp_t priority)
 #ifdef CONFIG_MEMCG
 	/* sk->sk_memcg will be populated at accept() time */
 	newsk->sk_memcg = NULL;
+	mem_cgroup_sk_set_flags(newsk, mem_cgroup_sk_get_flags(sk));
 #endif
 
 	cgroup_sk_clone(&newsk->sk_cgrp_data);
@@ -3153,8 +3158,11 @@ bool sk_page_frag_refill(struct sock *sk, struct page_frag *pfrag)
 	if (likely(skb_page_frag_refill(32U, pfrag, sk->sk_allocation)))
 		return true;
 
-	sk_enter_memory_pressure(sk);
+	if (sk_should_enter_memory_pressure(sk))
+		sk_enter_memory_pressure(sk);
+
 	sk_stream_moderate_sndbuf(sk);
+
 	return false;
 }
 EXPORT_SYMBOL(sk_page_frag_refill);
@@ -3267,18 +3275,30 @@ int __sk_mem_raise_allocated(struct sock *sk, int size, int amt, int kind)
 {
 	bool memcg_enabled = false, charged = false;
 	struct proto *prot = sk->sk_prot;
-	long allocated;
-
-	sk_memory_allocated_add(sk, amt);
-	allocated = sk_memory_allocated(sk);
+	long allocated = 0;
 
 	if (mem_cgroup_sk_enabled(sk)) {
+		bool isolated = mem_cgroup_sk_isolated(sk);
+
 		memcg_enabled = true;
 		charged = mem_cgroup_sk_charge(sk, amt, gfp_memcg_charge());
-		if (!charged)
+
+		if (isolated && charged)
+			return 1;
+
+		if (!charged) {
+			if (!isolated) {
+				sk_memory_allocated_add(sk, amt);
+				allocated = sk_memory_allocated(sk);
+			}
+
 			goto suppress_allocation;
+		}
 	}
 
+	sk_memory_allocated_add(sk, amt);
+	allocated = sk_memory_allocated(sk);
+
 	/* Under limit. */
 	if (allocated <= sk_prot_mem_limits(sk, 0)) {
 		sk_leave_memory_pressure(sk);
@@ -3357,7 +3377,8 @@ int __sk_mem_raise_allocated(struct sock *sk, int size, int amt, int kind)
 
 	trace_sock_exceed_buf_limit(sk, prot, allocated, kind);
 
-	sk_memory_allocated_sub(sk, amt);
+	if (allocated)
+		sk_memory_allocated_sub(sk, amt);
 
 	if (charged)
 		mem_cgroup_sk_uncharge(sk, amt);
@@ -3396,11 +3417,15 @@ EXPORT_SYMBOL(__sk_mem_schedule);
  */
 void __sk_mem_reduce_allocated(struct sock *sk, int amount)
 {
-	sk_memory_allocated_sub(sk, amount);
-
-	if (mem_cgroup_sk_enabled(sk))
+	if (mem_cgroup_sk_enabled(sk)) {
 		mem_cgroup_sk_uncharge(sk, amount);
 
+		if (mem_cgroup_sk_isolated(sk))
+			return;
+	}
+
+	sk_memory_allocated_sub(sk, amount);
+
 	if (sk_under_global_memory_pressure(sk) &&
 	    (sk_memory_allocated(sk) < sk_prot_mem_limits(sk, 0)))
 		sk_leave_memory_pressure(sk);
diff --git a/net/ipv4/af_inet.c b/net/ipv4/af_inet.c
index 76e38092cd8a..adbc8bcb760b 100644
--- a/net/ipv4/af_inet.c
+++ b/net/ipv4/af_inet.c
@@ -95,6 +95,7 @@
 #include <net/checksum.h>
 #include <net/ip.h>
 #include <net/protocol.h>
+#include <net/proto_memory.h>
 #include <net/arp.h>
 #include <net/route.h>
 #include <net/ip_fib.h>
@@ -753,6 +754,42 @@ EXPORT_SYMBOL(inet_stream_connect);
 void __inet_accept(struct socket *sock, struct socket *newsock,
 		   struct sock *newsk)
 {
+	/* TODO: use sk_clone_lock() in SCTP and remove protocol checks */
+	if (mem_cgroup_sockets_enabled &&
+	    (!IS_ENABLED(CONFIG_IP_SCTP) ||
+	     sk_is_tcp(newsk) || sk_is_mptcp(newsk))) {
+		gfp_t gfp = GFP_KERNEL | __GFP_NOFAIL;
+		unsigned short flags;
+
+		flags = mem_cgroup_sk_get_flags(newsk);
+		mem_cgroup_sk_alloc(newsk);
+
+		if (mem_cgroup_from_sk(newsk)) {
+			int amt;
+
+			mem_cgroup_sk_set_flags(newsk, flags);
+
+			/* The socket has not been accepted yet, no need
+			 * to look at newsk->sk_wmem_queued.
+			 */
+			amt = sk_mem_pages(newsk->sk_forward_alloc +
+					   atomic_read(&newsk->sk_rmem_alloc));
+			if (amt) {
+				/* This amt is already charged globally to
+				 * sk_prot->memory_allocated due to lack of
+				 * sk_memcg until accept(), thus we need to
+				 * reclaim it here if newsk is isolated.
+				 */
+				if (mem_cgroup_sk_isolated(newsk))
+					sk_memory_allocated_sub(newsk, amt);
+
+				mem_cgroup_sk_charge(newsk, amt, gfp);
+			}
+		}
+
+		kmem_cache_charge(newsk, gfp);
+	}
+
 	sock_rps_record_flow(newsk);
 	WARN_ON(!((1 << newsk->sk_state) &
 		  (TCPF_ESTABLISHED | TCPF_SYN_RECV |
diff --git a/net/ipv4/inet_connection_sock.c b/net/ipv4/inet_connection_sock.c
index 0ef1eacd539d..f8dd53d40dcf 100644
--- a/net/ipv4/inet_connection_sock.c
+++ b/net/ipv4/inet_connection_sock.c
@@ -22,6 +22,7 @@
 #include <net/tcp.h>
 #include <net/sock_reuseport.h>
 #include <net/addrconf.h>
+#include <net/proto_memory.h>
 
 #if IS_ENABLED(CONFIG_IPV6)
 /* match_sk*_wildcard == true: IPV6_ADDR_ANY equals to any IPv6 addresses
@@ -708,31 +709,6 @@ struct sock *inet_csk_accept(struct sock *sk, struct proto_accept_arg *arg)
 
 	release_sock(sk);
 
-	if (mem_cgroup_sockets_enabled) {
-		gfp_t gfp = GFP_KERNEL | __GFP_NOFAIL;
-		int amt = 0;
-
-		/* atomically get the memory usage, set and charge the
-		 * newsk->sk_memcg.
-		 */
-		lock_sock(newsk);
-
-		mem_cgroup_sk_alloc(newsk);
-		if (mem_cgroup_from_sk(newsk)) {
-			/* The socket has not been accepted yet, no need
-			 * to look at newsk->sk_wmem_queued.
-			 */
-			amt = sk_mem_pages(newsk->sk_forward_alloc +
-					   atomic_read(&newsk->sk_rmem_alloc));
-		}
-
-		if (amt)
-			mem_cgroup_sk_charge(newsk, amt, gfp);
-		kmem_cache_charge(newsk, gfp);
-
-		release_sock(newsk);
-	}
-
 	if (req)
 		reqsk_put(req);
 
diff --git a/net/ipv4/tcp.c b/net/ipv4/tcp.c
index 71a956fbfc55..dcbd49e2f8af 100644
--- a/net/ipv4/tcp.c
+++ b/net/ipv4/tcp.c
@@ -908,7 +908,8 @@ struct sk_buff *tcp_stream_alloc_skb(struct sock *sk, gfp_t gfp,
 		}
 		__kfree_skb(skb);
 	} else {
-		sk->sk_prot->enter_memory_pressure(sk);
+		if (sk_should_enter_memory_pressure(sk))
+			tcp_enter_memory_pressure(sk);
 		sk_stream_moderate_sndbuf(sk);
 	}
 	return NULL;
diff --git a/net/ipv4/tcp_output.c b/net/ipv4/tcp_output.c
index dfbac0876d96..f7aa86661219 100644
--- a/net/ipv4/tcp_output.c
+++ b/net/ipv4/tcp_output.c
@@ -3574,12 +3574,18 @@ void sk_forced_mem_schedule(struct sock *sk, int size)
 	delta = size - sk->sk_forward_alloc;
 	if (delta <= 0)
 		return;
+
 	amt = sk_mem_pages(delta);
 	sk_forward_alloc_add(sk, amt << PAGE_SHIFT);
-	sk_memory_allocated_add(sk, amt);
 
-	if (mem_cgroup_sk_enabled(sk))
+	if (mem_cgroup_sk_enabled(sk)) {
 		mem_cgroup_sk_charge(sk, amt, gfp_memcg_charge() | __GFP_NOFAIL);
+
+		if (mem_cgroup_sk_isolated(sk))
+			return;
+	}
+
+	sk_memory_allocated_add(sk, amt);
 }
 
 /* Send a FIN. The caller locks the socket for us.
diff --git a/net/mptcp/protocol.c b/net/mptcp/protocol.c
index 9a287b75c1b3..f7487e22a3f8 100644
--- a/net/mptcp/protocol.c
+++ b/net/mptcp/protocol.c
@@ -16,6 +16,7 @@
 #include <net/inet_hashtables.h>
 #include <net/protocol.h>
+#include <net/proto_memory.h>
 #include <net/tcp_states.h>
 #if IS_ENABLED(CONFIG_MPTCP_IPV6)
 #include <net/transp_v6.h>
 #endif
@@ -1016,7 +1017,7 @@ static void mptcp_enter_memory_pressure(struct sock *sk)
 	mptcp_for_each_subflow(msk, subflow) {
 		struct sock *ssk = mptcp_subflow_tcp_sock(subflow);
 
-		if (first)
+		if (first && sk_should_enter_memory_pressure(ssk))
 			tcp_enter_memory_pressure(ssk);
 		sk_stream_moderate_sndbuf(ssk);
 
diff --git a/net/tls/tls_device.c b/net/tls/tls_device.c
index f672a62a9a52..6696ef837116 100644
--- a/net/tls/tls_device.c
+++ b/net/tls/tls_device.c
@@ -35,6 +35,7 @@
 #include <linux/netdevice.h>
 #include <net/dst.h>
 #include <net/inet_connection_sock.h>
+#include <net/proto_memory.h>
 #include <net/tcp.h>
 #include <net/tls.h>
 #include <linux/skbuff_ref.h>
@@ -371,7 +372,8 @@ static int tls_do_allocation(struct sock *sk,
 	if (!offload_ctx->open_record) {
 		if (unlikely(!skb_page_frag_refill(prepend_size, pfrag,
 						   sk->sk_allocation))) {
-			READ_ONCE(sk->sk_prot)->enter_memory_pressure(sk);
+			if (sk_should_enter_memory_pressure(sk))
+				READ_ONCE(sk->sk_prot)->enter_memory_pressure(sk);
 			sk_stream_moderate_sndbuf(sk);
 			return -ENOMEM;
 		}
diff --git a/tools/include/uapi/linux/bpf.h b/tools/include/uapi/linux/bpf.h
index 233de8677382..52b8c2278589 100644
--- a/tools/include/uapi/linux/bpf.h
+++ b/tools/include/uapi/linux/bpf.h
@@ -7182,6 +7182,7 @@ enum {
 	TCP_BPF_SYN_MAC         = 1007, /* Copy the MAC, IP[46], and TCP header */
 	TCP_BPF_SOCK_OPS_CB_FLAGS = 1008, /* Get or Set TCP sock ops flags */
 	SK_BPF_CB_FLAGS		= 1009, /* Get or set sock ops flags in socket */
+	SK_BPF_MEMCG_FLAGS	= 1010, /* Get or Set flags saved in sk->sk_memcg */
 };
 
 enum {
@@ -7204,6 +7205,11 @@ enum {
 	 */
 };
 
+enum {
+	SK_BPF_MEMCG_SOCK_ISOLATED	= (1UL << 0),
+	SK_BPF_MEMCG_FLAG_MAX		= (1UL << 1),
+};
+
 struct bpf_perf_event_value {
 	__u64 counter;
 	__u64 enabled;
- */ - amt = sk_mem_pages(newsk->sk_forward_alloc + - atomic_read(&newsk->sk_rmem_alloc)); - } - - if (amt) - mem_cgroup_sk_charge(newsk, amt, gfp); - kmem_cache_charge(newsk, gfp); - - release_sock(newsk); - } - if (req) reqsk_put(req); diff --git a/net/ipv4/tcp.c b/net/ipv4/tcp.c index 71a956fbfc55..dcbd49e2f8af 100644 --- a/net/ipv4/tcp.c +++ b/net/ipv4/tcp.c @@ -908,7 +908,8 @@ struct sk_buff *tcp_stream_alloc_skb(struct sock *sk, gfp_t gfp, } __kfree_skb(skb); } else { - sk->sk_prot->enter_memory_pressure(sk); + if (sk_should_enter_memory_pressure(sk)) + tcp_enter_memory_pressure(sk); sk_stream_moderate_sndbuf(sk); } return NULL; diff --git a/net/ipv4/tcp_output.c b/net/ipv4/tcp_output.c index dfbac0876d96..f7aa86661219 100644 --- a/net/ipv4/tcp_output.c +++ b/net/ipv4/tcp_output.c @@ -3574,12 +3574,18 @@ void sk_forced_mem_schedule(struct sock *sk, int size) delta = size - sk->sk_forward_alloc; if (delta <= 0) return; + amt = sk_mem_pages(delta); sk_forward_alloc_add(sk, amt << PAGE_SHIFT); - sk_memory_allocated_add(sk, amt); - if (mem_cgroup_sk_enabled(sk)) + if (mem_cgroup_sk_enabled(sk)) { mem_cgroup_sk_charge(sk, amt, gfp_memcg_charge() | __GFP_NOFAIL); + + if (mem_cgroup_sk_isolated(sk)) + return; + } + + sk_memory_allocated_add(sk, amt); } /* Send a FIN. The caller locks the socket for us. diff --git a/net/mptcp/protocol.c b/net/mptcp/protocol.c index 9a287b75c1b3..f7487e22a3f8 100644 --- a/net/mptcp/protocol.c +++ b/net/mptcp/protocol.c @@ -16,6 +16,7 @@ #include #include #include +#include #include #if IS_ENABLED(CONFIG_MPTCP_IPV6) #include @@ -1016,7 +1017,7 @@ static void mptcp_enter_memory_pressure(struct sock *sk) mptcp_for_each_subflow(msk, subflow) { struct sock *ssk = mptcp_subflow_tcp_sock(subflow); - if (first) + if (first && sk_should_enter_memory_pressure(ssk)) tcp_enter_memory_pressure(ssk); sk_stream_moderate_sndbuf(ssk); diff --git a/net/tls/tls_device.c b/net/tls/tls_device.c index f672a62a9a52..6696ef837116 100644 --- a/net/tls/tls_device.c +++ b/net/tls/tls_device.c @@ -35,6 +35,7 @@ #include #include #include +#include #include #include #include @@ -371,7 +372,8 @@ static int tls_do_allocation(struct sock *sk, if (!offload_ctx->open_record) { if (unlikely(!skb_page_frag_refill(prepend_size, pfrag, sk->sk_allocation))) { - READ_ONCE(sk->sk_prot)->enter_memory_pressure(sk); + if (sk_should_enter_memory_pressure(sk)) + READ_ONCE(sk->sk_prot)->enter_memory_pressure(sk); sk_stream_moderate_sndbuf(sk); return -ENOMEM; } diff --git a/tools/include/uapi/linux/bpf.h b/tools/include/uapi/linux/bpf.h index 233de8677382..52b8c2278589 100644 --- a/tools/include/uapi/linux/bpf.h +++ b/tools/include/uapi/linux/bpf.h @@ -7182,6 +7182,7 @@ enum { TCP_BPF_SYN_MAC = 1007, /* Copy the MAC, IP[46], and TCP header */ TCP_BPF_SOCK_OPS_CB_FLAGS = 1008, /* Get or Set TCP sock ops flags */ SK_BPF_CB_FLAGS = 1009, /* Get or set sock ops flags in socket */ + SK_BPF_MEMCG_FLAGS = 1010, /* Get or Set flags saved in sk->sk_memcg */ }; enum { @@ -7204,6 +7205,11 @@ enum { */ }; +enum { + SK_BPF_MEMCG_SOCK_ISOLATED = (1UL << 0), + SK_BPF_MEMCG_FLAG_MAX = (1UL << 1), +}; + struct bpf_perf_event_value { __u64 counter; __u64 enabled; diff --git a/tools/testing/selftests/bpf/prog_tests/sk_memcg.c b/tools/testing/selftests/bpf/prog_tests/sk_memcg.c new file mode 100644 index 000000000000..2d68b00419a2 --- /dev/null +++ b/tools/testing/selftests/bpf/prog_tests/sk_memcg.c @@ -0,0 +1,218 @@ +// SPDX-License-Identifier: GPL-2.0 +/* Copyright 2025 Google LLC */ + +#include +#include "sk_memcg.skel.h" 
+
+static int check_isolated(struct test_case *test_case, bool isolated)
+{
+	char buf[BUF_SINGLE] = {};
+	long memory_allocated[2];
+	int sk[NR_SOCKETS] = {};
+	int err = -1, i, j;
+
+	memory_allocated[0] = get_memory_allocated(test_case);
+	if (!ASSERT_GE(memory_allocated[0], 0, "memory_allocated[0]"))
+		goto out;
+
+	err = test_case->create_sockets(test_case, sk, ARRAY_SIZE(sk));
+	if (err)
+		goto close;
+
+	/* Must allocate pages >= net.core.mem_pcpu_rsv */
+	for (i = 0; i < ARRAY_SIZE(sk); i++) {
+		for (j = 0; j < NR_SEND; j++) {
+			int bytes = send(sk[i], buf, sizeof(buf), 0);
+
+			/* Avoid too noisy logs when something failed. */
+			if (bytes != sizeof(buf))
+				ASSERT_EQ(bytes, sizeof(buf), "send");
+		}
+	}
+
+	memory_allocated[1] = get_memory_allocated(test_case);
+	if (!ASSERT_GE(memory_allocated[1], 0, "memory_allocated[1]"))
+		goto close;
+
+	if (isolated) {
+		ASSERT_LE(memory_allocated[1], memory_allocated[0] + 10, "isolated");
+	} else {
+		/* By default, net.core.mem_pcpu_rsv == 256 pages */
+		ASSERT_GT(memory_allocated[1], memory_allocated[0] + 256, "not isolated");
+	}
+
+close:
+	for (i = 0; i < ARRAY_SIZE(sk); i++)
+		close(sk[i]);
+
+	if (test_case->type == SOCK_DGRAM) {
+		/* Give 150ms to let RCU destruct UDP sockets */
+		usleep(150 * 1000);
+	}
+out:
+	return err;
+}
+
+static void run_test(struct test_case *test_case)
+{
+	struct sk_memcg *skel;
+	int cgroup, err;
+
+	skel = sk_memcg__open_and_load();
+	if (!ASSERT_OK_PTR(skel, "open_and_load"))
+		return;
+
+	cgroup = test__join_cgroup("/sk_memcg");
+	if (!ASSERT_GE(cgroup, 0, "join_cgroup"))
+		goto destroy_skel;
+
+	err = check_isolated(test_case, false);
+	if (!ASSERT_EQ(err, 0, "check_isolated(false)"))
+		goto close_cgroup;
+
+	skel->links.sock_create = bpf_program__attach_cgroup(skel->progs.sock_create, cgroup);
+	if (!ASSERT_OK_PTR(skel->links.sock_create, "attach_cgroup(sock_create)"))
+		goto close_cgroup;
+
+	err = check_isolated(test_case, true);
+	ASSERT_EQ(err, 0, "check_isolated(true)");
+
+close_cgroup:
+	close(cgroup);
+destroy_skel:
+	sk_memcg__destroy(skel);
+}
+
+struct test_case test_cases[] = {
+	{
+		.name = "TCP      ",
+		.family = AF_INET,
+		.type = SOCK_STREAM,
+		.create_sockets = tcp_create_sockets,
+	},
+	{
+		.name = "UDP      ",
+		.family = AF_INET,
+		.type = SOCK_DGRAM,
+		.create_sockets = udp_create_sockets,
+	},
+	{
+		.name = "TCPv6    ",
+		.family = AF_INET6,
+		.type = SOCK_STREAM,
+		.create_sockets = tcp_create_sockets,
+	},
+	{
+		.name = "UDPv6    ",
+		.family = AF_INET6,
+		.type = SOCK_DGRAM,
+		.create_sockets = udp_create_sockets,
+	},
+};
+
+void serial_test_sk_memcg(void)
+{
+	int i;
+
+	for (i = 0; i < ARRAY_SIZE(test_cases); i++) {
+		if (test__start_subtest(test_cases[i].name))
+			run_test(&test_cases[i]);
+	}
+}
diff --git a/tools/testing/selftests/bpf/progs/sk_memcg.c b/tools/testing/selftests/bpf/progs/sk_memcg.c
new file mode 100644
index 000000000000..a613c1deeede
--- /dev/null
+++ b/tools/testing/selftests/bpf/progs/sk_memcg.c
@@ -0,0 +1,38 @@
+// SPDX-License-Identifier: GPL-2.0
+/* Copyright 2025 Google LLC */
+
+#include "bpf_tracing_net.h"
+#include <bpf/bpf_helpers.h>
+#include <errno.h>
+
+SEC("cgroup/sock_create")
+int sock_create(struct bpf_sock *ctx)
+{
+	u32 flags = SK_BPF_MEMCG_SOCK_ISOLATED;
+	int err;
+
+	err = bpf_setsockopt(ctx, SOL_SOCKET, SK_BPF_MEMCG_FLAGS,
+			     &flags, sizeof(flags));
+	if (err)
+		goto err;
+
+	flags = 0;
+
+	err = bpf_getsockopt(ctx, SOL_SOCKET, SK_BPF_MEMCG_FLAGS,
+			     &flags, sizeof(flags));
+	if (err)
+		goto err;
+
+	if (flags != SK_BPF_MEMCG_SOCK_ISOLATED) {
+		err = -EINVAL;
+		goto err;
+	}
+
+	return 1;
+
+err:
+	bpf_set_retval(err);
+	return 0;
+}
+
+char LICENSE[] SEC("license") = "GPL";
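
A note for trying the new option outside the selftest harness: the sock_create program above can be attached to a cgroup with plain libbpf. The sketch below is illustrative only; it assumes libbpf >= 1.0 (NULL-on-error semantics), and the object file name, cgroup path, and bpf-fs pin path are placeholders, not names defined by this patch.

/* sk_memcg_loader.c: attach the sock_create program to a cgroup and pin
 * the link so the attachment outlives this process.
 */
#include <fcntl.h>
#include <unistd.h>
#include <bpf/libbpf.h>

int main(void)
{
	struct bpf_object *obj;
	struct bpf_program *prog;
	struct bpf_link *link;
	int cg_fd, err = 1;

	obj = bpf_object__open_file("sk_memcg.bpf.o", NULL);	/* placeholder object */
	if (!obj)
		return 1;

	if (bpf_object__load(obj))
		goto out;

	prog = bpf_object__find_program_by_name(obj, "sock_create");
	if (!prog)
		goto out;

	/* The cgroup whose new sockets should be memcg-isolated. */
	cg_fd = open("/sys/fs/cgroup/isolated.slice", O_RDONLY | O_DIRECTORY);
	if (cg_fd < 0)
		goto out;

	/* The attach type (BPF_CGROUP_INET_SOCK_CREATE) is inferred from SEC(). */
	link = bpf_program__attach_cgroup(prog, cg_fd);

	/* Pinning keeps the link (and the policy) alive after exit. */
	if (link && !bpf_link__pin(link, "/sys/fs/bpf/sk_memcg_link"))
		err = 0;

	close(cg_fd);
out:
	bpf_object__close(obj);
	return err;
}

Once the link is pinned, every socket created in that cgroup gets SK_BPF_MEMCG_SOCK_ISOLATED set at creation time, so its memory is charged to the memcg only and no longer counts toward the global per-protocol limits or memory-pressure signals.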