Skip to content

bpf: Fix use-after-free of sockmap #8686

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 19 additions & 3 deletions net/core/skmsg.c
Original file line number Diff line number Diff line change
@@ -655,6 +655,14 @@ static void sk_psock_backlog(struct work_struct *work)
bool ingress;
int ret;

/* Increment the psock refcnt to synchronize with close(fd) path in
* sock_map_close(), ensuring we wait for backlog thread completion
* before sk_socket freed. If refcnt increment fails, it indicates
* sock_map_close() completed with sk_socket potentially already freed.
*/
if (!sk_psock_get(psock->sk))
return;

mutex_lock(&psock->work_mutex);
if (unlikely(state->len)) {
len = state->len;
@@ -702,6 +710,7 @@ static void sk_psock_backlog(struct work_struct *work)
}
end:
mutex_unlock(&psock->work_mutex);
sk_psock_put(psock->sk, psock);
}

struct sk_psock *sk_psock_init(struct sock *sk, int node)
@@ -1222,17 +1231,24 @@ static int sk_psock_verdict_recv(struct sock *sk, struct sk_buff *skb)

static void sk_psock_verdict_data_ready(struct sock *sk)
{
struct socket *sock = sk->sk_socket;
struct socket *sock;
const struct proto_ops *ops;
int copied;

trace_sk_data_ready(sk);

if (unlikely(!sock))
rcu_read_lock();
sock = sk->sk_socket;
if (unlikely(!sock)) {
rcu_read_unlock();
return;
}
ops = READ_ONCE(sock->ops);
if (!ops || !ops->read_skb)
if (!ops || !ops->read_skb) {
rcu_read_unlock();
return;
}
rcu_read_unlock();
copied = ops->read_skb(sk, sk_psock_verdict_recv);
if (copied >= 0) {
struct sk_psock *psock;
13 changes: 12 additions & 1 deletion tools/testing/selftests/bpf/prog_tests/socket_helpers.h
Original file line number Diff line number Diff line change
@@ -313,11 +313,22 @@ static inline int recv_timeout(int fd, void *buf, size_t len, int flags,

static inline int create_pair(int family, int sotype, int *p0, int *p1)
{
__close_fd int s, c = -1, p = -1;
__close_fd int s = -1, c = -1, p = -1;
struct sockaddr_storage addr;
socklen_t len = sizeof(addr);
int err;

if (family == AF_UNIX) {
int fds[2];

err = socketpair(family, sotype, 0, fds);
if (!err) {
*p0 = fds[0];
*p1 = fds[1];
}
return err;
}

s = socket_loopback(family, sotype);
if (s < 0)
return s;
60 changes: 60 additions & 0 deletions tools/testing/selftests/bpf/prog_tests/sockmap_basic.c
Original file line number Diff line number Diff line change
@@ -1042,6 +1042,59 @@ static void test_sockmap_vsock_unconnected(void)
xclose(map);
}

void *close_thread(void *arg)
{
int *fd = (int *)arg;

sleep(1);
close(*fd);
*fd = -1;
return NULL;
}

void test_sockmap_with_close_on_write(int family, int sotype)
{
struct test_sockmap_pass_prog *skel;
int err, map, verdict;
pthread_t tid;
int zero = 0;
int c = -1, p = -1;

skel = test_sockmap_pass_prog__open_and_load();
if (!ASSERT_OK_PTR(skel, "open_and_load"))
return;

verdict = bpf_program__fd(skel->progs.prog_skb_verdict);
map = bpf_map__fd(skel->maps.sock_map_rx);

err = bpf_prog_attach(verdict, map, BPF_SK_SKB_STREAM_VERDICT, 0);
if (!ASSERT_OK(err, "bpf_prog_attach"))
goto out;

err = create_pair(family, sotype, &c, &p);
if (!ASSERT_OK(err, "create_pair"))
goto out;

err = bpf_map_update_elem(map, &zero, &p, BPF_ANY);
if (!ASSERT_OK(err, "bpf_map_update_elem"))
goto out;

err = pthread_create(&tid, 0, close_thread, &p);
if (!ASSERT_OK(err, "pthread_create"))
goto out;

while (p > 0)
send(c, "a", 1, MSG_NOSIGNAL);

pthread_join(tid, NULL);
out:
if (c > 0)
close(c);
if (p > 0)
close(p);
test_sockmap_pass_prog__destroy(skel);
}

void test_sockmap_basic(void)
{
if (test__start_subtest("sockmap create_update_free"))
@@ -1108,4 +1161,11 @@ void test_sockmap_basic(void)
test_sockmap_skb_verdict_vsock_poll();
if (test__start_subtest("sockmap vsock unconnected"))
test_sockmap_vsock_unconnected();
if (test__start_subtest("sockmap with write on close")) {
test_sockmap_with_close_on_write(AF_UNIX, SOCK_STREAM);
test_sockmap_with_close_on_write(AF_UNIX, SOCK_DGRAM);
test_sockmap_with_close_on_write(AF_INET, SOCK_STREAM);
test_sockmap_with_close_on_write(AF_INET, SOCK_DGRAM);
test_sockmap_with_close_on_write(AF_VSOCK, SOCK_STREAM);
}
}