Skip to content

bpf: Add kfuncs for read-only string operations #8709

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
299 changes: 299 additions & 0 deletions kernel/bpf/helpers.c
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// SPDX-License-Identifier: GPL-2.0-only
/* Copyright (c) 2011-2014 PLUMgrid, http://plumgrid.com
*/
#include "linux/uaccess.h"
#include <linux/bpf.h>
#include <linux/btf.h>
#include <linux/bpf-cgroup.h>
Expand Down Expand Up @@ -3195,6 +3196,291 @@ __bpf_kfunc void bpf_local_irq_restore(unsigned long *flags__irq_flag)
local_irq_restore(*flags__irq_flag);
}

/* Kfuncs for string operations.
*
* Since strings are not necessarily %NUL-terminated, we cannot directly call
* in-kernel implementations. Instead, unbounded variants are open-coded with
* using __get_kernel_nofault instead of plain dereference to make them safe.
* Bounded variants use params with the __sz suffix so safety is assured by the
* verifier and we can use the in-kernel (potentially optimized) functions.
*/

/**
* bpf_strcmp - Compare two strings
* @cs: One string
* @ct: Another string
*/
__bpf_kfunc int bpf_strcmp(const char *cs, const char *ct)
{
int i = 0, ret = 0;
char c1, c2;

pagefault_disable();
while (i++ < XATTR_SIZE_MAX) {
__get_kernel_nofault(&c1, cs++, char, cs_out);
__get_kernel_nofault(&c2, ct++, char, ct_out);
if (c1 != c2) {
ret = c1 < c2 ? -1 : 1;
goto out;
}
if (!c1)
goto out;
}
cs_out:
ret = -1;
goto out;
ct_out:
ret = 1;
out:
pagefault_enable();
return ret;
}

/**
* bpf_strchr - Find the first occurrence of a character in a string
* @s: The string to be searched
* @c: The character to search for
*
* Note that the %NUL-terminator is considered part of the string, and can
* be searched for.
*/
__bpf_kfunc char *bpf_strchr(const char *s, int c)
{
char *ret = NULL;
int i = 0;
char sc;

pagefault_disable();
while (i++ < XATTR_SIZE_MAX) {
__get_kernel_nofault(&sc, s, char, out);
if (sc == (char)c) {
ret = (char *)s;
break;
}
if (sc == '\0')
break;
s++;
}
out:
pagefault_enable();
return ret;
}

/**
* bpf_strchrnul - Find and return a character in a string, or end of string
* @s: The string to be searched
* @c: The character to search for
*
* Returns pointer to first occurrence of 'c' in s. If c is not found, then
* return a pointer to the null byte at the end of s.
*/
__bpf_kfunc char *bpf_strchrnul(const char *s, int c)
{
char *ret = NULL;
int i = 0;
char sc;

pagefault_disable();
while (i++ < XATTR_SIZE_MAX) {
__get_kernel_nofault(&sc, s, char, out);
if (sc == '\0' || sc == (char)c) {
ret = (char *)s;
break;
}
s++;
}
out:
pagefault_enable();
return ret;
}

/**
* bpf_strnchr - Find a character in a length limited string
* @s: The string to be searched
* @s__sz: The number of characters to be searched
* @c: The character to search for
*
* Note that the %NUL-terminator is considered part of the string, and can
* be searched for.
*/
__bpf_kfunc char *bpf_strnchr(void *s, u32 s__sz, int c)
{
return strnchr(s, s__sz, c);
}

/**
* bpf_strnchrnul - Find and return a character in a length limited string,
* or end of string
* @s: The string to be searched
* @s__sz: The number of characters to be searched
* @c: The character to search for
*
* Returns pointer to the first occurrence of 'c' in s. If c is not found,
* then return a pointer to the last character of the string.
*/
__bpf_kfunc char *bpf_strnchrnul(void *s, u32 s__sz, int c)
{
return strnchrnul(s, s__sz, c);
}

/**
* bpf_strrchr - Find the last occurrence of a character in a string
* @s: The string to be searched
* @c: The character to search for
*/
__bpf_kfunc char *bpf_strrchr(const char *s, int c)
{
char *ret = NULL;
int i = 0;
char sc;

pagefault_disable();
while (i++ < XATTR_SIZE_MAX) {
__get_kernel_nofault(&sc, s, char, out);
if (sc == '\0')
break;
if (sc == (char)c)
ret = (char *)s;
s++;
}
out:
pagefault_enable();
return (char *)ret;
}

__bpf_kfunc size_t bpf_strlen(const char *s)
{
int i = 0;
char c;

pagefault_disable();
while (i < XATTR_SIZE_MAX) {
__get_kernel_nofault(&c, s++, char, out);
if (c == '\0')
break;
i++;
}
out:
pagefault_enable();
return i;
}

__bpf_kfunc size_t bpf_strnlen(void *s, u32 s__sz)
{
return strnlen(s, s__sz);
}

/**
* bpf_strspn - Calculate the length of the initial substring of @s which only contain letters in @accept
* @s: The string to be searched
* @accept: The string to search for
*/
__bpf_kfunc size_t bpf_strspn(const char *s, const char *accept)
{
int i = 0;
char c;

pagefault_disable();
while (i < XATTR_SIZE_MAX) {
__get_kernel_nofault(&c, s++, char, out);
if (c == '\0' || !bpf_strchr(accept, c))
break;
i++;
}
out:
pagefault_enable();
return i;
}

/**
* strcspn - Calculate the length of the initial substring of @s which does not contain letters in @reject
* @s: The string to be searched
* @reject: The string to avoid
*/
__bpf_kfunc size_t bpf_strcspn(const char *s, const char *reject)
{
int i = 0;
char c;

pagefault_disable();
while (i < XATTR_SIZE_MAX) {
__get_kernel_nofault(&c, s++, char, out);
if (c == '\0' || bpf_strchr(reject, c))
break;
i++;
}
out:
pagefault_enable();
return i;
}

/**
* bpf_strpbrk - Find the first occurrence of a set of characters
* @cs: The string to be searched
* @ct: The characters to search for
*/
__bpf_kfunc char *bpf_strpbrk(const char *cs, const char *ct)
{
char *ret = NULL;
int i = 0;
char c;

pagefault_disable();
while (i++ < XATTR_SIZE_MAX) {
__get_kernel_nofault(&c, cs, char, out);
if (c == '\0')
break;
if (bpf_strchr(ct, c)) {
ret = (char *)cs;
break;
}
cs++;
}
out:
pagefault_enable();
return ret;
}

/**
* bpf_strstr - Find the first substring in a %NUL terminated string
* @s1: The string to be searched
* @s2: The string to search for
*/
__bpf_kfunc char *bpf_strstr(const char *s1, const char *s2)
{
size_t l1, l2;

l2 = bpf_strlen(s2);
if (!l2)
return (char *)s1;
l1 = bpf_strlen(s1);
while (l1 >= l2) {
l1--;
if (!memcmp(s1, s2, l2))
return (char *)s1;
s1++;
}
return NULL;
}

/**
* bpf_strnstr - Find the first substring in a length-limited string
* @s1: The string to be searched
* @s1__sz: The size of @s1
* @s2: The string to search for
* @s2__sz: The size of @s2
*/
__bpf_kfunc char *bpf_strnstr(void *s1, u32 s1__sz, void *s2, u32 s2__sz)
{
/* strnstr() uses strlen() to get the length of s2. Since this is not
* safe in BPF context for non-%NUL-terminated strings, use strnlen
* first to make it safe.
*/
if (strnlen(s2, s2__sz) == s2__sz)
return NULL;
return strnstr(s1, s2, s1__sz);
}

__bpf_kfunc_end_defs();

BTF_KFUNCS_START(generic_btf_ids)
Expand Down Expand Up @@ -3295,6 +3581,19 @@ BTF_ID_FLAGS(func, bpf_iter_kmem_cache_next, KF_ITER_NEXT | KF_RET_NULL | KF_SLE
BTF_ID_FLAGS(func, bpf_iter_kmem_cache_destroy, KF_ITER_DESTROY | KF_SLEEPABLE)
BTF_ID_FLAGS(func, bpf_local_irq_save)
BTF_ID_FLAGS(func, bpf_local_irq_restore)
BTF_ID_FLAGS(func, bpf_strcmp);
BTF_ID_FLAGS(func, bpf_strchr);
BTF_ID_FLAGS(func, bpf_strchrnul);
BTF_ID_FLAGS(func, bpf_strnchr);
BTF_ID_FLAGS(func, bpf_strnchrnul);
BTF_ID_FLAGS(func, bpf_strrchr);
BTF_ID_FLAGS(func, bpf_strlen);
BTF_ID_FLAGS(func, bpf_strnlen);
BTF_ID_FLAGS(func, bpf_strspn);
BTF_ID_FLAGS(func, bpf_strcspn);
BTF_ID_FLAGS(func, bpf_strpbrk);
BTF_ID_FLAGS(func, bpf_strstr);
BTF_ID_FLAGS(func, bpf_strnstr);
BTF_KFUNCS_END(common_btf_ids)

static const struct btf_kfunc_id_set common_kfunc_set = {
Expand Down
2 changes: 2 additions & 0 deletions tools/testing/selftests/bpf/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -811,6 +811,7 @@ $(OUTPUT)/bench_local_storage_create.o: $(OUTPUT)/bench_local_storage_create.ske
$(OUTPUT)/bench_bpf_hashmap_lookup.o: $(OUTPUT)/bpf_hashmap_lookup.skel.h
$(OUTPUT)/bench_htab_mem.o: $(OUTPUT)/htab_mem_bench.skel.h
$(OUTPUT)/bench_bpf_crypto.o: $(OUTPUT)/crypto_bench.skel.h
$(OUTPUT)/bench_string_kfuncs.o: $(OUTPUT)/string_kfuncs_bench.skel.h
$(OUTPUT)/bench.o: bench.h testing_helpers.h $(BPFOBJ)
$(OUTPUT)/bench: LDLIBS += -lm
$(OUTPUT)/bench: $(OUTPUT)/bench.o \
Expand All @@ -831,6 +832,7 @@ $(OUTPUT)/bench: $(OUTPUT)/bench.o \
$(OUTPUT)/bench_local_storage_create.o \
$(OUTPUT)/bench_htab_mem.o \
$(OUTPUT)/bench_bpf_crypto.o \
$(OUTPUT)/bench_string_kfuncs.o \
#
$(call msg,BINARY,,$@)
$(Q)$(CC) $(CFLAGS) $(LDFLAGS) $(filter %.a %.o,$^) $(LDLIBS) -o $@
Expand Down
21 changes: 21 additions & 0 deletions tools/testing/selftests/bpf/bench.c
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,7 @@ extern struct argp bench_local_storage_create_argp;
extern struct argp bench_htab_mem_argp;
extern struct argp bench_trigger_batch_argp;
extern struct argp bench_crypto_argp;
extern struct argp bench_string_kfuncs_argp;

static const struct argp_child bench_parsers[] = {
{ &bench_ringbufs_argp, 0, "Ring buffers benchmark", 0 },
Expand All @@ -297,6 +298,7 @@ static const struct argp_child bench_parsers[] = {
{ &bench_htab_mem_argp, 0, "hash map memory benchmark", 0 },
{ &bench_trigger_batch_argp, 0, "BPF triggering benchmark", 0 },
{ &bench_crypto_argp, 0, "bpf crypto benchmark", 0 },
{ &bench_string_kfuncs_argp, 0, "string kfuncs benchmark", 0 },
{},
};

Expand Down Expand Up @@ -550,6 +552,16 @@ extern const struct bench bench_htab_mem;
extern const struct bench bench_crypto_encrypt;
extern const struct bench bench_crypto_decrypt;

/* string kfunc benchmarks */
extern const struct bench bench_string_kfuncs_strlen;
extern const struct bench bench_string_kfuncs_strnlen;
extern const struct bench bench_string_kfuncs_strchr;
extern const struct bench bench_string_kfuncs_strnchr;
extern const struct bench bench_string_kfuncs_strchrnul;
extern const struct bench bench_string_kfuncs_strnchrnul;
extern const struct bench bench_string_kfuncs_strstr;
extern const struct bench bench_string_kfuncs_strnstr;

static const struct bench *benchs[] = {
&bench_count_global,
&bench_count_local,
Expand Down Expand Up @@ -609,6 +621,15 @@ static const struct bench *benchs[] = {
&bench_htab_mem,
&bench_crypto_encrypt,
&bench_crypto_decrypt,
/* string kfuncs */
&bench_string_kfuncs_strlen,
&bench_string_kfuncs_strnlen,
&bench_string_kfuncs_strchr,
&bench_string_kfuncs_strnchr,
&bench_string_kfuncs_strchrnul,
&bench_string_kfuncs_strnchrnul,
&bench_string_kfuncs_strstr,
&bench_string_kfuncs_strnstr,
};

static void find_benchmark(void)
Expand Down
Loading
Loading