diff --git a/mypyc/irbuild/specialize.py b/mypyc/irbuild/specialize.py index f652449f5289..35b58649f173 100644 --- a/mypyc/irbuild/specialize.py +++ b/mypyc/irbuild/specialize.py @@ -49,6 +49,7 @@ RTuple, RType, bool_rprimitive, + bytes_rprimitive, c_int_rprimitive, dict_rprimitive, int16_rprimitive, @@ -83,6 +84,11 @@ join_formatted_strings, tokenizer_format_call, ) +from mypyc.primitives.bytes_ops import ( + bytes_decode_ascii_strict, + bytes_decode_latin1_strict, + bytes_decode_utf8_strict, +) from mypyc.primitives.dict_ops import ( dict_items_op, dict_keys_op, @@ -740,6 +746,58 @@ def str_encode_fast_path(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> return None +@specialize_function("decode", bytes_rprimitive) +def bytes_decode_fast_path(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Value | None: + if not isinstance(callee, MemberExpr): + return None + + encoding = "utf8" + errors = "strict" + + # Handle up to 2 arguments: decode([encoding], [errors]) + if len(expr.arg_kinds) > 0 and isinstance(expr.args[0], StrExpr): + if expr.arg_kinds[0] == ARG_NAMED: + if expr.arg_names[0] == "encoding": + encoding = expr.args[0].value + elif expr.arg_names[0] == "errors": + errors = expr.args[0].value + elif expr.arg_kinds[0] == ARG_POS: + encoding = expr.args[0].value + else: + return None + + if len(expr.arg_kinds) > 1 and isinstance(expr.args[1], StrExpr): + if expr.arg_kinds[1] == ARG_NAMED: + if expr.arg_names[1] == "encoding": + encoding = expr.args[1].value + elif expr.arg_names[1] == "errors": + errors = expr.args[1].value + elif expr.arg_kinds[1] == ARG_POS: + errors = expr.args[1].value + else: + return None + + if errors != "strict": + return None + + normalized = encoding.lower().replace("-", "").replace("_", "") + + if normalized in ("utf8", "utf", "u8", "cp65001"): + return builder.primitive_op( + bytes_decode_utf8_strict, [builder.accept(callee.expr)], expr.line + ) + elif normalized in ("ascii", "usascii", "646"): + return builder.primitive_op( + bytes_decode_ascii_strict, [builder.accept(callee.expr)], expr.line + ) + elif normalized in ("latin1", "latin", "iso88591", "cp819", "8859", "l1"): + return builder.primitive_op( + bytes_decode_latin1_strict, [builder.accept(callee.expr)], expr.line + ) + + return None + + @specialize_function("mypy_extensions.i64") def translate_i64(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Value | None: if len(expr.args) != 1 or expr.arg_kinds[0] != ARG_POS: diff --git a/mypyc/lib-rt/CPy.h b/mypyc/lib-rt/CPy.h index 1f0cf4dd63d6..aca7a6b23e6a 100644 --- a/mypyc/lib-rt/CPy.h +++ b/mypyc/lib-rt/CPy.h @@ -764,6 +764,9 @@ CPyTagged CPyBytes_GetItem(PyObject *o, CPyTagged index); PyObject *CPyBytes_Concat(PyObject *a, PyObject *b); PyObject *CPyBytes_Join(PyObject *sep, PyObject *iter); CPyTagged CPyBytes_Ord(PyObject *obj); +PyObject *CPy_DecodeUtf8(PyObject *bytes_obj, const char *errors); +PyObject *CPy_DecodeLatin1(PyObject *bytes_obj, const char *errors); +PyObject *CPy_DecodeAscii(PyObject *bytes_obj, const char *errors); int CPyBytes_Compare(PyObject *left, PyObject *right); diff --git a/mypyc/lib-rt/bytes_ops.c b/mypyc/lib-rt/bytes_ops.c index 6ff34b021a9a..4f7652e1cd1c 100644 --- a/mypyc/lib-rt/bytes_ops.c +++ b/mypyc/lib-rt/bytes_ops.c @@ -162,3 +162,42 @@ CPyTagged CPyBytes_Ord(PyObject *obj) { PyErr_SetString(PyExc_TypeError, "ord() expects a character"); return CPY_INT_TAG; } + + +PyObject *CPy_DecodeUtf8(PyObject *bytes_obj, const char *errors) { + if (!PyBytes_Check(bytes_obj)) { + PyErr_SetString(PyExc_TypeError, "expected bytes object"); + return NULL; + } + + char *data = PyBytes_AS_STRING(bytes_obj); + Py_ssize_t size = PyBytes_GET_SIZE(bytes_obj); + + return PyUnicode_DecodeUTF8(data, size, errors); +} + + +PyObject *CPy_DecodeLatin1(PyObject *bytes_obj, const char *errors) { + if (!PyBytes_Check(bytes_obj)) { + PyErr_SetString(PyExc_TypeError, "expected bytes object"); + return NULL; + } + + char *data = PyBytes_AS_STRING(bytes_obj); + Py_ssize_t size = PyBytes_GET_SIZE(bytes_obj); + + return PyUnicode_DecodeLatin1(data, size, errors); +} + + +PyObject *CPy_DecodeAscii(PyObject *bytes_obj, const char *errors) { + if (!PyBytes_Check(bytes_obj)) { + PyErr_SetString(PyExc_TypeError, "expected bytes object"); + return NULL; + } + + char *data = PyBytes_AS_STRING(bytes_obj); + Py_ssize_t size = PyBytes_GET_SIZE(bytes_obj); + + return PyUnicode_DecodeASCII(data, size, errors); +} diff --git a/mypyc/primitives/bytes_ops.py b/mypyc/primitives/bytes_ops.py index 1afd196cff84..3ad920bc9480 100644 --- a/mypyc/primitives/bytes_ops.py +++ b/mypyc/primitives/bytes_ops.py @@ -18,6 +18,7 @@ ERR_NEG_INT, binary_op, custom_op, + custom_primitive_op, function_op, load_address_op, method_op, @@ -107,3 +108,27 @@ c_function_name="CPyBytes_Ord", error_kind=ERR_MAGIC, ) + +bytes_decode_utf8_strict = custom_primitive_op( + name="decode", + arg_types=[bytes_rprimitive, str_rprimitive], + return_type=str_rprimitive, + c_function_name="CPy_DecodeUtf8", + error_kind=ERR_MAGIC, +) + +bytes_decode_latin1_strict = custom_primitive_op( + name="decode_latin1", + arg_types=[bytes_rprimitive, str_rprimitive], + return_type=str_rprimitive, + c_function_name="CPy_DecodeLatin1", + error_kind=ERR_MAGIC, +) + +bytes_decode_ascii_strict = custom_primitive_op( + name="decode_ascii", + arg_types=[bytes_rprimitive, str_rprimitive], + return_type=str_rprimitive, + c_function_name="CPy_DecodeAscii", + error_kind=ERR_MAGIC, +) diff --git a/mypyc/test-data/irbuild-bytes.test b/mypyc/test-data/irbuild-bytes.test index 476c5ac59f48..1a11442ac91a 100644 --- a/mypyc/test-data/irbuild-bytes.test +++ b/mypyc/test-data/irbuild-bytes.test @@ -185,3 +185,38 @@ L0: r10 = CPyBytes_Build(2, var, r9) b4 = r10 return 1 + +[case testDecodeBytes] +def f(b: bytes) -> None: + b.decode() + b.decode('utf8') + b.decode('utf-8', 'strict') + b.decode('utf-8', 'strict') + b.decode('latin1', 'strict') + b.decode('ascii') + b.decode('latin-1') + b.decode('utf-8', 'ignore') + b.decode('ascii', 'replace') + b.decode('latin1', 'ignore') +[out] +def f(b): + b :: bytes + r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, r10, r11, r12, r13, r14, r15 :: str +L0: + r0 = CPy_DecodeUtf8(b) + r1 = CPy_DecodeUtf8(b) + r2 = CPy_DecodeUtf8(b) + r3 = CPy_DecodeUtf8(b) + r4 = CPy_DecodeLatin1(b) + r5 = CPy_DecodeAscii(b) + r6 = CPy_DecodeLatin1(b) + r7 = 'utf-8' + r8 = 'ignore' + r9 = CPy_Decode(b, r7, r8) + r10 = 'ascii' + r11 = 'replace' + r12 = CPy_Decode(b, r10, r11) + r13 = 'latin1' + r14 = 'ignore' + r15 = CPy_Decode(b, r13, r14) + return 1