Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions include/net/proto_memory.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,22 @@ static inline bool sk_under_memory_pressure(const struct sock *sk)
if (!sk->sk_prot->memory_pressure)
return false;

if (mem_cgroup_sk_enabled(sk) &&
mem_cgroup_sk_under_memory_pressure(sk))
return true;
if (mem_cgroup_sk_enabled(sk)) {
if (mem_cgroup_sk_under_memory_pressure(sk))
return true;

if (mem_cgroup_sk_isolated(sk))
return false;
}

return !!READ_ONCE(*sk->sk_prot->memory_pressure);
}

static inline bool sk_should_enter_memory_pressure(struct sock *sk)
{
return !mem_cgroup_sk_enabled(sk) || !mem_cgroup_sk_isolated(sk);
}

static inline long
proto_memory_allocated(const struct proto *prot)
{
Expand Down
50 changes: 50 additions & 0 deletions include/net/sock.h
Original file line number Diff line number Diff line change
Expand Up @@ -2596,17 +2596,53 @@ static inline gfp_t gfp_memcg_charge(void)
return in_softirq() ? GFP_ATOMIC : GFP_KERNEL;
}

/* BPF may stash SK_BPF_MEMCG_* flags in the unused low bits of
 * sk->sk_memcg; FLAG_MASK selects the flag bits, PTR_MASK the pointer.
 * Fully parenthesize the expansions so the macros are safe inside any
 * larger expression (the unary ~ must bind to the whole mask).
 */
#define SK_BPF_MEMCG_FLAG_MASK	(SK_BPF_MEMCG_FLAG_MAX - 1)
#define SK_BPF_MEMCG_PTR_MASK	(~SK_BPF_MEMCG_FLAG_MASK)

#ifdef CONFIG_MEMCG
/* Return the mem_cgroup associated with @sk.
 *
 * With CONFIG_CGROUP_BPF, the low bits of sk->sk_memcg may carry
 * SK_BPF_MEMCG_* flags, so they are masked off before the value is
 * used as a pointer.  Without it, sk->sk_memcg is a plain pointer.
 */
static inline struct mem_cgroup *mem_cgroup_from_sk(const struct sock *sk)
{
#ifdef CONFIG_CGROUP_BPF
	unsigned long val = (unsigned long)sk->sk_memcg;

	val &= SK_BPF_MEMCG_PTR_MASK;
	return (struct mem_cgroup *)val;
#else
	return sk->sk_memcg;
#endif
}

/* Store @flags in the low bits of sk->sk_memcg.
 *
 * mem_cgroup_from_sk() returns the pointer with all flag bits already
 * cleared, so OR-ing @flags in replaces any previously stored flags
 * rather than accumulating them.  No-op without CONFIG_CGROUP_BPF,
 * where no flag bits are reserved in the pointer.
 */
static inline void mem_cgroup_sk_set_flags(struct sock *sk, unsigned short flags)
{
#ifdef CONFIG_CGROUP_BPF
	unsigned long val = (unsigned long)mem_cgroup_from_sk(sk);

	val |= flags;
	sk->sk_memcg = (struct mem_cgroup *)val;
#endif
}

/* Return the SK_BPF_MEMCG_* flags stored in the low bits of
 * sk->sk_memcg, or 0 when CONFIG_CGROUP_BPF is off and no flag bits
 * exist.
 */
static inline unsigned short mem_cgroup_sk_get_flags(const struct sock *sk)
{
#ifdef CONFIG_CGROUP_BPF
	unsigned long val = (unsigned long)sk->sk_memcg;

	return val & SK_BPF_MEMCG_FLAG_MASK;
#else
	return 0;
#endif
}

/* True when memcg socket-memory accounting is globally enabled and
 * @sk is actually associated with a mem_cgroup.
 */
static inline bool mem_cgroup_sk_enabled(const struct sock *sk)
{
	return mem_cgroup_sockets_enabled && mem_cgroup_from_sk(sk);
}

/* True when BPF marked @sk as isolated: its memory is charged to the
 * memcg only, bypassing the global per-protocol accounting and limits.
 */
static inline bool mem_cgroup_sk_isolated(const struct sock *sk)
{
	return mem_cgroup_sk_get_flags(sk) & SK_BPF_MEMCG_SOCK_ISOLATED;
}

static inline bool mem_cgroup_sk_under_memory_pressure(const struct sock *sk)
{
struct mem_cgroup *memcg = mem_cgroup_from_sk(sk);
Expand All @@ -2629,11 +2665,25 @@ static inline struct mem_cgroup *mem_cgroup_from_sk(const struct sock *sk)
return NULL;
}

/* !CONFIG_MEMCG stub: there is nowhere to store flags, so they are
 * silently ignored.
 */
static inline void mem_cgroup_sk_set_flags(struct sock *sk, unsigned short flags)
{
}

/* !CONFIG_MEMCG stub: no flags are ever stored. */
static inline unsigned short mem_cgroup_sk_get_flags(const struct sock *sk)
{
	return 0;
}

/* !CONFIG_MEMCG stub: memcg socket accounting is never enabled. */
static inline bool mem_cgroup_sk_enabled(const struct sock *sk)
{
	return false;
}

/* !CONFIG_MEMCG stub: no socket can be memcg-isolated. */
static inline bool mem_cgroup_sk_isolated(const struct sock *sk)
{
	return false;
}

static inline bool mem_cgroup_sk_under_memory_pressure(const struct sock *sk)
{
return false;
Expand Down
10 changes: 7 additions & 3 deletions include/net/tcp.h
Original file line number Diff line number Diff line change
Expand Up @@ -275,9 +275,13 @@ extern unsigned long tcp_memory_pressure;
/* optimized version of sk_under_memory_pressure() for TCP sockets */
static inline bool tcp_under_memory_pressure(const struct sock *sk)
{
if (mem_cgroup_sk_enabled(sk) &&
mem_cgroup_sk_under_memory_pressure(sk))
return true;
if (mem_cgroup_sk_enabled(sk)) {
if (mem_cgroup_sk_under_memory_pressure(sk))
return true;

if (mem_cgroup_sk_isolated(sk))
return false;
}

return READ_ONCE(tcp_memory_pressure);
}
Expand Down
6 changes: 6 additions & 0 deletions include/uapi/linux/bpf.h
Original file line number Diff line number Diff line change
Expand Up @@ -7182,6 +7182,7 @@ enum {
TCP_BPF_SYN_MAC = 1007, /* Copy the MAC, IP[46], and TCP header */
TCP_BPF_SOCK_OPS_CB_FLAGS = 1008, /* Get or Set TCP sock ops flags */
SK_BPF_CB_FLAGS = 1009, /* Get or set sock ops flags in socket */
SK_BPF_MEMCG_FLAGS = 1010, /* Get or Set flags saved in sk->sk_memcg */
};

enum {
Expand All @@ -7204,6 +7205,11 @@ enum {
*/
};

/* Flag values for bpf_setsockopt(SK_BPF_MEMCG_FLAGS), stored in the
 * unused low bits of sk->sk_memcg.
 */
enum {
	SK_BPF_MEMCG_SOCK_ISOLATED	= (1UL << 0),	/* charge the socket's memcg only */
	SK_BPF_MEMCG_FLAG_MAX		= (1UL << 1),	/* first invalid value; keep last */
};

struct bpf_perf_event_value {
__u64 counter;
__u64 enabled;
Expand Down
91 changes: 90 additions & 1 deletion net/core/filter.c
Original file line number Diff line number Diff line change
Expand Up @@ -5267,6 +5267,35 @@ static int sk_bpf_set_get_cb_flags(struct sock *sk, char *optval, bool getopt)
return 0;
}

static int sk_bpf_get_memcg_flags(struct sock *sk, char *optval)
{
if (!sk_has_account(sk))
return -EOPNOTSUPP;

*(u32 *)optval = mem_cgroup_sk_get_flags(sk);

return 0;
}

/* setsockopt(SK_BPF_MEMCG_FLAGS) handler: validate and store @optval's
 * u32 flag word into the low bits of sk->sk_memcg.  The check order
 * (optlen, accounting support, flag range) determines which error a
 * caller observes, so it is preserved.
 */
static int sk_bpf_set_memcg_flags(struct sock *sk, char *optval, int optlen)
{
	u32 val;

	if (optlen != sizeof(val))
		return -EINVAL;

	if (!sk_has_account(sk))
		return -EOPNOTSUPP;

	val = *(u32 *)optval;
	if (val >= SK_BPF_MEMCG_FLAG_MAX)
		return -EINVAL;

	mem_cgroup_sk_set_flags(sk, val);
	return 0;
}

static int sol_socket_sockopt(struct sock *sk, int optname,
char *optval, int *optlen,
bool getopt)
Expand All @@ -5284,6 +5313,7 @@ static int sol_socket_sockopt(struct sock *sk, int optname,
case SO_BINDTOIFINDEX:
case SO_TXREHASH:
case SK_BPF_CB_FLAGS:
case SK_BPF_MEMCG_FLAGS:
if (*optlen != sizeof(int))
return -EINVAL;
break;
Expand All @@ -5293,8 +5323,15 @@ static int sol_socket_sockopt(struct sock *sk, int optname,
return -EINVAL;
}

if (optname == SK_BPF_CB_FLAGS)
switch (optname) {
case SK_BPF_CB_FLAGS:
return sk_bpf_set_get_cb_flags(sk, optval, getopt);
case SK_BPF_MEMCG_FLAGS:
if (!IS_ENABLED(CONFIG_MEMCG) || !getopt)
return -EOPNOTSUPP;

return sk_bpf_get_memcg_flags(sk, optval);
}

if (getopt) {
if (optname == SO_BINDTODEVICE)
Expand Down Expand Up @@ -5723,6 +5760,44 @@ static const struct bpf_func_proto bpf_sock_addr_getsockopt_proto = {
.arg5_type = ARG_CONST_SIZE,
};

/* bpf_setsockopt() variant for sockets that are not locked (attach type
 * BPF_CGROUP_INET_SOCK_CREATE).  SK_BPF_MEMCG_FLAGS is intercepted here
 * because it only updates sk->sk_memcg; every other option is delegated
 * to the generic __bpf_setsockopt() path.
 */
BPF_CALL_5(bpf_unlocked_sock_setsockopt, struct sock *, sk, int, level,
	   int, optname, char *, optval, int, optlen)
{
	if (IS_ENABLED(CONFIG_MEMCG) &&
	    level == SOL_SOCKET && optname == SK_BPF_MEMCG_FLAGS)
		return sk_bpf_set_memcg_flags(sk, optval, optlen);

	return __bpf_setsockopt(sk, level, optname, optval, optlen);
}

/* Verifier contract for bpf_unlocked_sock_setsockopt(): ctx is the raw
 * socket pointer, optval is read-only memory of optlen bytes.
 */
static const struct bpf_func_proto bpf_unlocked_sock_setsockopt_proto = {
	.func		= bpf_unlocked_sock_setsockopt,
	.gpl_only	= false,
	.ret_type	= RET_INTEGER,
	.arg1_type	= ARG_PTR_TO_CTX,
	.arg2_type	= ARG_ANYTHING,
	.arg3_type	= ARG_ANYTHING,
	.arg4_type	= ARG_PTR_TO_MEM | MEM_RDONLY,
	.arg5_type	= ARG_CONST_SIZE,
};

/* bpf_getsockopt() variant for sockets that are not locked (attach type
 * BPF_CGROUP_INET_SOCK_CREATE); a thin pass-through to the generic
 * __bpf_getsockopt() path.
 */
BPF_CALL_5(bpf_unlocked_sock_getsockopt, struct sock *, sk, int, level,
	   int, optname, char *, optval, int, optlen)
{
	return __bpf_getsockopt(sk, level, optname, optval, optlen);
}

/* Verifier contract for bpf_unlocked_sock_getsockopt(): ctx is the raw
 * socket pointer, optval is writable (uninitialized) memory of optlen
 * bytes that the helper fills in.
 */
static const struct bpf_func_proto bpf_unlocked_sock_getsockopt_proto = {
	.func		= bpf_unlocked_sock_getsockopt,
	.gpl_only	= false,
	.ret_type	= RET_INTEGER,
	.arg1_type	= ARG_PTR_TO_CTX,
	.arg2_type	= ARG_ANYTHING,
	.arg3_type	= ARG_ANYTHING,
	.arg4_type	= ARG_PTR_TO_UNINIT_MEM,
	.arg5_type	= ARG_CONST_SIZE,
};

BPF_CALL_5(bpf_sock_ops_setsockopt, struct bpf_sock_ops_kern *, bpf_sock,
int, level, int, optname, char *, optval, int, optlen)
{
Expand Down Expand Up @@ -8051,6 +8126,20 @@ sock_filter_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog)
return &bpf_sk_storage_get_cg_sock_proto;
case BPF_FUNC_ktime_get_coarse_ns:
return &bpf_ktime_get_coarse_ns_proto;
case BPF_FUNC_setsockopt:
switch (prog->expected_attach_type) {
case BPF_CGROUP_INET_SOCK_CREATE:
return &bpf_unlocked_sock_setsockopt_proto;
default:
return NULL;
}
case BPF_FUNC_getsockopt:
switch (prog->expected_attach_type) {
case BPF_CGROUP_INET_SOCK_CREATE:
return &bpf_unlocked_sock_getsockopt_proto;
default:
return NULL;
}
default:
return bpf_base_func_proto(func_id, prog);
}
Expand Down
65 changes: 45 additions & 20 deletions net/core/sock.c
Original file line number Diff line number Diff line change
Expand Up @@ -1046,17 +1046,21 @@ static int sock_reserve_memory(struct sock *sk, int bytes)
if (!charged)
return -ENOMEM;

/* pre-charge to forward_alloc */
sk_memory_allocated_add(sk, pages);
allocated = sk_memory_allocated(sk);
/* If the system goes into memory pressure with this
* precharge, give up and return error.
*/
if (allocated > sk_prot_mem_limits(sk, 1)) {
sk_memory_allocated_sub(sk, pages);
mem_cgroup_sk_uncharge(sk, pages);
return -ENOMEM;
if (!mem_cgroup_sk_isolated(sk)) {
/* pre-charge to forward_alloc */
sk_memory_allocated_add(sk, pages);
allocated = sk_memory_allocated(sk);

/* If the system goes into memory pressure with this
* precharge, give up and return error.
*/
if (allocated > sk_prot_mem_limits(sk, 1)) {
sk_memory_allocated_sub(sk, pages);
mem_cgroup_sk_uncharge(sk, pages);
return -ENOMEM;
}
}

sk_forward_alloc_add(sk, pages << PAGE_SHIFT);

WRITE_ONCE(sk->sk_reserved_mem,
Expand Down Expand Up @@ -2515,6 +2519,7 @@ struct sock *sk_clone_lock(const struct sock *sk, const gfp_t priority)
#ifdef CONFIG_MEMCG
/* sk->sk_memcg will be populated at accept() time */
newsk->sk_memcg = NULL;
mem_cgroup_sk_set_flags(newsk, mem_cgroup_sk_get_flags(sk));
#endif

cgroup_sk_clone(&newsk->sk_cgrp_data);
Expand Down Expand Up @@ -3153,8 +3158,11 @@ bool sk_page_frag_refill(struct sock *sk, struct page_frag *pfrag)
if (likely(skb_page_frag_refill(32U, pfrag, sk->sk_allocation)))
return true;

sk_enter_memory_pressure(sk);
if (sk_should_enter_memory_pressure(sk))
sk_enter_memory_pressure(sk);

sk_stream_moderate_sndbuf(sk);

return false;
}
EXPORT_SYMBOL(sk_page_frag_refill);
Expand Down Expand Up @@ -3267,18 +3275,30 @@ int __sk_mem_raise_allocated(struct sock *sk, int size, int amt, int kind)
{
bool memcg_enabled = false, charged = false;
struct proto *prot = sk->sk_prot;
long allocated;

sk_memory_allocated_add(sk, amt);
allocated = sk_memory_allocated(sk);
long allocated = 0;

if (mem_cgroup_sk_enabled(sk)) {
bool isolated = mem_cgroup_sk_isolated(sk);

memcg_enabled = true;
charged = mem_cgroup_sk_charge(sk, amt, gfp_memcg_charge());
if (!charged)

if (isolated && charged)
return 1;

if (!charged) {
if (!isolated) {
sk_memory_allocated_add(sk, amt);
allocated = sk_memory_allocated(sk);
}

goto suppress_allocation;
}
}

sk_memory_allocated_add(sk, amt);
allocated = sk_memory_allocated(sk);

/* Under limit. */
if (allocated <= sk_prot_mem_limits(sk, 0)) {
sk_leave_memory_pressure(sk);
Expand Down Expand Up @@ -3357,7 +3377,8 @@ int __sk_mem_raise_allocated(struct sock *sk, int size, int amt, int kind)

trace_sock_exceed_buf_limit(sk, prot, allocated, kind);

sk_memory_allocated_sub(sk, amt);
if (allocated)
sk_memory_allocated_sub(sk, amt);

if (charged)
mem_cgroup_sk_uncharge(sk, amt);
Expand Down Expand Up @@ -3396,11 +3417,15 @@ EXPORT_SYMBOL(__sk_mem_schedule);
*/
void __sk_mem_reduce_allocated(struct sock *sk, int amount)
{
sk_memory_allocated_sub(sk, amount);

if (mem_cgroup_sk_enabled(sk))
if (mem_cgroup_sk_enabled(sk)) {
mem_cgroup_sk_uncharge(sk, amount);

if (mem_cgroup_sk_isolated(sk))
return;
}

sk_memory_allocated_sub(sk, amount);

if (sk_under_global_memory_pressure(sk) &&
(sk_memory_allocated(sk) < sk_prot_mem_limits(sk, 0)))
sk_leave_memory_pressure(sk);
Expand Down
Loading
Loading