diff --git a/osprey_worker/src/osprey/engine/stdlib/udfs/string.py b/osprey_worker/src/osprey/engine/stdlib/udfs/string.py index 720da9ce..12fe2a2b 100644 --- a/osprey_worker/src/osprey/engine/stdlib/udfs/string.py +++ b/osprey_worker/src/osprey/engine/stdlib/udfs/string.py @@ -400,3 +400,29 @@ def execute(self, execution_context: ExecutionContext, arguments: StringArgument # return any valid urls encountered in the message return list(valid_urls) + + +_TOKEN_PATTERN = re.compile(r"[\w]+(?:'[\w]+)?", re.UNICODE) + + +def tokenize_text(s: str) -> list[str]: + # replaces "curly" apostrophes with a normal apostrophe for a simpler regex + s = s.replace('\u2019', "'").replace('\u02bc', "'") + return _TOKEN_PATTERN.findall(s.lower()) + + +class StringTokenize(UDFBase[StringArguments, list[str]]): + """ + Used to convert the given string into a list of individual tokens. Returns a list of individual + tokens split by spaces and punctuation marks. + + Note that StringTokenize does not split on a single apostrophe found inside a word (e.g. contractions). + For example, the string "don't go" would result in ["don't", "go"]. Tokens are sequences of word + characters with at most one internal apostrophe, and the string "do''not''go" would result in + ["do", "not", "go"]. + """ + + category = UdfCategories.STRING + + def execute(self, execution_context: ExecutionContext, arguments: StringArguments) -> list[str]: + return tokenize_text(arguments.s) diff --git a/osprey_worker/src/osprey/engine/stdlib/udfs/tests/test_strings.py b/osprey_worker/src/osprey/engine/stdlib/udfs/tests/test_strings.py index 38a31067..bcada009 100644 --- a/osprey_worker/src/osprey/engine/stdlib/udfs/tests/test_strings.py +++ b/osprey_worker/src/osprey/engine/stdlib/udfs/tests/test_strings.py @@ -17,6 +17,7 @@ StringSplit, StringStartsWith, StringStrip, + StringTokenize, StringToLower, StringToUpper, ) @@ -39,6 +40,7 @@ StringToUpper, StringExtractDomains, StringExtractURLs, + StringTokenize, ) ), ] @@ -319,3 +321,34 @@ def test_extract_urls(execute: ExecuteFunction, text: str, expected_result: List result: List[str] = data['Result'] assert len(expected_result) == len(result) assert set(expected_result) == set(result) + + +@pytest.mark.parametrize( + 'text,expected_result', + [ + ('the cat in the box', ['the', 'cat', 'in', 'the', 'box']), + ('the Cat in the bOx', ['the', 'cat', 'in', 'the', 'box']), + ("i'm going to the store", ["i'm", 'going', 'to', 'the', 'store']), + ('hello. where are you going? over here!', ['hello', 'where', 'are', 'you', 'going', 'over', 'here']), + ('hello123world', ['hello123world']), + ('test 456 test', ['test', '456', 'test']), + ('the cat', ['the', 'cat']), + ('hello\\tworld\\ntest', ['hello', 'world', 'test']), + ('hello, world!', ['hello', 'world']), + ('end. start', ['end', 'start']), + ('café résumé', ['café', 'résumé']), + ('donʼt', ["don't"]), # curly apostrophe (u02bc) + ('don’t', ["don't"]), # curly apostrophe (u2019) + ("cat's", ["cat's"]), + ("''hello", ['hello']), + ("test''test", ['test', 'test']), + ], +) +def test_tokenize(execute: ExecuteFunction, text: str, expected_result: List[str]) -> None: + data: Dict[str, Any] = execute(f""" + Result = StringTokenize(s="{text}") + """) + + result: List[str] = data['Result'] + assert len(expected_result) == len(result) + assert expected_result == result diff --git a/osprey_worker/src/osprey/engine/stdlib/udfs/tests/test_unicode_censored.py b/osprey_worker/src/osprey/engine/stdlib/udfs/tests/test_unicode_censored.py new file mode 100644 index 00000000..725a8d56 --- /dev/null +++ b/osprey_worker/src/osprey/engine/stdlib/udfs/tests/test_unicode_censored.py @@ -0,0 +1,215 @@ +import pytest +from osprey.engine.conftest import ExecuteFunction +from osprey.engine.stdlib.udfs.unicode_censored import StringCheckCensored +from osprey.engine.udf.registry import UDFRegistry + +pytestmark = [ + pytest.mark.use_udf_registry( + UDFRegistry.with_udfs( + StringCheckCensored, + ) + ), +] + + +class TestCheckCensoredUDF: + """Tests for the StringCheckCensored UDF.""" + + def test_basic_match(self, execute: ExecuteFunction): + data = execute(""" + Result = StringCheckCensored(s="cat", pattern="cat") + """) + assert data['Result'] is True + + def test_censored_match(self, execute: ExecuteFunction): + data = execute(""" + Result = StringCheckCensored(s="c@t", pattern="cat") + """) + assert data['Result'] is True + + def test_no_match(self, execute: ExecuteFunction): + data = execute(""" + Result = StringCheckCensored(s="dog", pattern="cat") + """) + assert data['Result'] is False + + def test_unicode_lookalike_match(self, execute: ExecuteFunction): + # using cyrillic 'а' which looks like latin 'a' + data = execute(""" + Result = StringCheckCensored(s="cаt", pattern="cat") + """) + assert data['Result'] is True + + data = execute(""" + Result = StringCheckCensored(s="𝒞𝞪𝔗", pattern="cat") + """) + assert data['Result'] is True + + def test_plural_option_enabled(self, execute: ExecuteFunction): + data = execute(""" + WithPlural = StringCheckCensored(s="cats", pattern="cat", plurals=True) + WithoutPlural = StringCheckCensored(s="cats", pattern="cat", plurals=False) + """) + assert data['WithPlural'] is True + assert data['WithoutPlural'] is False + + def test_substring_option(self, execute: ExecuteFunction): + data = execute(""" + WithSubstring = StringCheckCensored(s="concatenate", pattern="cat", substrings=True) + WithoutSubstring = StringCheckCensored(s="concatenate", pattern="cat", substrings=False) + """) + assert data['WithSubstring'] is True + assert data['WithoutSubstring'] is False + + def test_must_be_censored_option(self, execute: ExecuteFunction): + data = execute(""" + PlainText = StringCheckCensored(s="cat", pattern="cat", must_be_censored=True) + CensoredText = StringCheckCensored(s="c@t", pattern="cat", must_be_censored=True) + """) + assert data['PlainText'] is False + assert data['CensoredText'] is True + + def test_must_be_censored_with_surrounding_test(self, execute: ExecuteFunction): + data = execute(""" + PlainText = StringCheckCensored(s="the cat sat", pattern="cat", must_be_censored=True) + CensoredText = StringCheckCensored(s="the c@t sat", pattern="cat", must_be_censored=True) + """) + assert data['PlainText'] is False + assert data['CensoredText'] is True + + def test_case_insensitive(self, execute: ExecuteFunction): + data = execute(""" + Upper = StringCheckCensored(s="CAT", pattern="cat") + Mixed = StringCheckCensored(s="CaT", pattern="cat") + """) + assert data['Upper'] is True + assert data['Mixed'] is True + + def test_with_special_chars_in_pattern(self, execute: ExecuteFunction): + data = execute(""" + Result = StringCheckCensored(s="a.b", pattern="a.b") + """) + assert data['Result'] is True + + def test_empty_string(self, execute: ExecuteFunction): + data = execute(""" + Result = StringCheckCensored(s="", pattern="cat") + """) + assert data['Result'] is False + + @pytest.mark.parametrize( + 'input_str,pattern,expected', + [ + ('hello world', 'hello', True), + ('h3ll0', 'hello', True), + ('h e l l o', 'hello', False), + ('HELLO', 'hello', True), + ('dog', 'cat', False), + ('', 'test', False), + ('c@t', 'cat', True), + ('h4ck3r', 'hacker', True), + ('p@$$w0rd', 'password', True), + ('t35t', 'test', True), + ('1337', 'leet', True), + ('n00b', 'noob', True), + ('ph1sh', 'phish', True), + ('саt', 'cat', True), + ('руthоn', 'python', True), + ('НЕLLО', 'hello', True), + ('Ηello', 'hello', True), + ('Αpple', 'apple', True), + ('Βank', 'bank', True), + ('Κing', 'king', True), + ('Νice', 'nice', True), + ('Οpen', 'open', True), + ('Ρython', 'python', True), + ('Τest', 'test', True), + ('Χmas', 'xmas', True), + ('Υes', 'yes', True), + ('Ζero', 'zero', True), + ('𝐜𝐚𝐭', 'cat', True), + ('𝑐𝑎𝑡', 'cat', True), + ('𝒄𝒂𝒕', 'cat', True), + ('𝓬𝓪𝓽', 'cat', True), + ('𝔠𝔞𝔱', 'cat', True), + ('𝕔𝕒𝕥', 'cat', True), + ('𝖈𝖆𝖙', 'cat', True), + ('𝗰𝗮𝘁', 'cat', True), + ('𝘤𝘢𝘵', 'cat', True), + ('𝙘𝙖𝙩', 'cat', True), + ('𝚌𝚊𝚝', 'cat', True), + ('cat', 'cat', True), + ('hello', 'hello', True), + ('HELLO', 'hello', True), + ('ᑕᗩT', 'cat', True), + ('ꓚꓮT', 'cat', True), + ('ⲤⲀT', 'cat', True), + ('ԁоg', 'dog', True), + ('bаnk', 'bank', True), + ('pаypаl', 'paypal', True), + ('аmаzоn', 'amazon', True), + ('s3cur1ty', 'security', True), + ('4dm1n', 'admin', True), + ('r00t', 'root', True), + ('z3r0', 'zero', True), + ('0n3', 'one', True), + ('tw0', 'two', True), + ('с@т', 'cat', True), + ('ρ@$$ωθrd', 'password', True), + ('𝕙𝕒𝕔𝕜', 'hack', True), + ('հello', 'hello', True), + ('ոice', 'nice', True), + ('քhone', 'phone', True), + ('ցame', 'game', True), + ('(at', 'cat', True), + (' str: + """ + Normalize the string with unicodedata then translate any of the homoglyphs we have in our table above + """ + s = unicodedata.normalize('NFKC', s) + s = s.translate(_HOMOGLYPH_TABLE) + return s.lower() + + +@lru_cache(maxsize=10_000) +def _build_pattern(token: str, include_plural: bool, include_substrings: bool) -> re.Pattern[str]: + regex = '' + + # start off by adding a word boundary check if we dont want substrings + if not include_substrings: + regex += r'(?:^|\W)' + + regex += '(?P' + + # add all of the leet characters to the regex + for index, char in enumerate(token): + char_class = [char] + if char in _LEET_CHARS: + char_class.extend(_LEET_CHARS[char]) + + regex += '[' + ''.join(re.escape(c) for c in char_class) + ']' + + if index < len(token) - 1: + regex += _SEPARATOR_PATTERN + + # optionally allow for plurals + if include_plural: + regex += _SEPARATOR_PATTERN + regex += '[s$5]?' + + regex += ')' + + # end in a word boundary if we dont want substrings + if not include_substrings: + regex += r'(?:\W|$)' + + return re.compile(regex, re.IGNORECASE) + + +def _strip_separators(s: str) -> str: + return ''.join(c for c in s if c not in _SEPARATOR_CHARS) + + +class StringCheckCensoredArguments(StringArguments): + pattern: str + """ + The string to create a regex pattern for. + """ + + plurals: bool = False + """ + Whether to check for plurals of the string as well. I.e. if the input is 'cat', match both 'cat' and 'cats'. + + Default: False + """ + + substrings: bool = False + """ + Whether to check substrings of the input string. I.e. 'concatenate' would match the pattern created for 'cat'. + + Default: False + """ + + must_be_censored: bool = False + """ + Whether a string must be censored to return True. For example, 'cat' itself would return false but 'c@t' + would return true. + + Default: False + """ + + +class StringCheckCensored(UDFBase[StringCheckCensoredArguments, bool]): + """ + Checks a given string, check against another string's censored regex. + + For example, if given the input pattern of "cat", will match tokens like "c4t", "<@t", or "c___a___t". + + It is recommended to use individual tokens to check against. For example, you should not attempt to match against + a long string, but rather individual words within that string. + """ + + category = UdfCategories.STRING + + def execute(self, execution_context: ExecutionContext, arguments: StringCheckCensoredArguments) -> bool: + normalized_token = _normalize_for_match(arguments.s) + + pattern_lower = arguments.pattern.lower() + regex = _build_pattern(pattern_lower, arguments.plurals, arguments.substrings) + + match = regex.search(normalized_token) + if match is None: + return False + + if arguments.must_be_censored: + matched_word = match.group('word') + matched_stripped = _strip_separators(matched_word).lower() + if matched_stripped == pattern_lower: + return False + + return True diff --git a/osprey_worker/src/osprey/worker/_stdlibplugin/udf_register.py b/osprey_worker/src/osprey/worker/_stdlibplugin/udf_register.py index caa3ab70..8359b5be 100644 --- a/osprey_worker/src/osprey/worker/_stdlibplugin/udf_register.py +++ b/osprey_worker/src/osprey/worker/_stdlibplugin/udf_register.py @@ -41,6 +41,7 @@ StringSplit, StringStartsWith, StringStrip, + StringTokenize, StringToLower, StringToUpper, ) @@ -58,6 +59,7 @@ ) from osprey.engine.stdlib.udfs.time_delta import TimeDelta from osprey.engine.stdlib.udfs.time_since import TimeSince +from osprey.engine.stdlib.udfs.unicode_censored import StringCheckCensored from osprey.engine.stdlib.udfs.verdicts import DeclareVerdict from osprey.engine.udf.base import UDFBase from osprey.worker.adaptor.plugin_manager import hookimpl_osprey @@ -108,12 +110,14 @@ def register_udfs() -> Sequence[Type[UDFBase[Any, Any]]]: HashSha1, HashSha256, HashSha512, + StringCheckCensored, StringLength, StringToLower, StringToUpper, StringStartsWith, StringEndsWith, StringStrip, + StringTokenize, StringRStrip, StringLStrip, StringReplace,