Skip to content

Commit b816dab

Browse files
authored
feat(cython): Python Stable ABI Friendliness (#72)
1 parent 81cfdee commit b816dab

File tree

1 file changed

+32
-23
lines changed

1 file changed

+32
-23
lines changed

python/mlc/_cython/core.pyx

Lines changed: 32 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,17 @@ from libcpp.vector cimport vector
66
from libc.stdint cimport int8_t, int16_t, int32_t, int64_t, uint8_t, uint16_t, uint32_t, uint64_t
77
from libc.stdlib cimport malloc, free
88
from numbers import Integral, Number
9-
from cpython cimport Py_DECREF, Py_INCREF, PyCapsule_IsValid, PyCapsule_GetPointer, PyCapsule_SetName, PyCapsule_New
109
from . import base
10+
from cpython.pycapsule cimport (
11+
PyCapsule_IsValid,
12+
PyCapsule_GetPointer,
13+
PyCapsule_SetName,
14+
PyCapsule_New,
15+
)
16+
17+
cdef extern from "Python.h":
18+
void Py_IncRef(object)
19+
void Py_DecRef(object)
1120

1221
Ptr = base.Ptr
1322
PyCode_NewEmpty = ctypes.pythonapi.PyCode_NewEmpty
@@ -417,21 +426,17 @@ cdef class PyAny:
417426
raise e.with_traceback(None)
418427
return _any_c2py_no_inc_ref(c_ret)
419428

420-
cdef class Str(str):
421-
cdef MLCAny _mlc_any
422-
__slots__ = ()
423-
424-
def __cinit__(self):
425-
self._mlc_any = _MLCAnyNone()
426429

427-
def __init__(self, value):
428-
cdef str value_unicode = self
429-
cdef bytes value_c = str_py2c(value_unicode)
430-
self._mlc_any = _MLCAnyRawStr(value_c)
431-
_check_error(_C_AnyInplaceViewToOwned(&self._mlc_any))
430+
class Str(str):
431+
__slots__ = ("_pyany",)
432432

433-
def __dealloc__(self):
434-
_check_error(_C_AnyDecRef(&self._mlc_any))
433+
def __new__(cls, value: str):
434+
cdef PyAny pyany = PyAny()
435+
self = super().__new__(cls, value)
436+
self._pyany = pyany
437+
pyany._mlc_any = _MLCAnyRawStr(str_py2c(value))
438+
_check_error(_C_AnyInplaceViewToOwned(&pyany._mlc_any))
439+
return self
435440

436441
def __reduce__(self):
437442
return (Str, (str(self),))
@@ -541,7 +546,7 @@ cdef inline object _any_c2py_no_inc_ref(const MLCAny x):
541546
cdef int32_t type_index = x.type_index
542547
cdef MLCStr* mlc_str = NULL
543548
cdef PyAny any_ret
544-
cdef Str str_ret
549+
cdef object str_ret
545550
if type_index == kMLCNone:
546551
return None
547552
elif type_index == kMLCBool:
@@ -556,8 +561,10 @@ cdef inline object _any_c2py_no_inc_ref(const MLCAny x):
556561
return str_c2py(x.v.v_str)
557562
elif type_index == kMLCStr:
558563
mlc_str = <MLCStr*>(x.v.v_obj)
564+
any_ret = PyAny()
565+
any_ret._mlc_any = x
559566
str_ret = Str.__new__(Str, str_c2py(mlc_str.data[:mlc_str.length]))
560-
str_ret._mlc_any = x
567+
str_ret._pyany = any_ret
561568
return str_ret
562569
elif type_index == kMLCOpaque:
563570
return <object>((<MLCOpaque*>(x.v.v_obj)).handle)
@@ -572,7 +579,7 @@ cdef inline object _any_c2py_inc_ref(MLCAny x):
572579
cdef int32_t type_index = x.type_index
573580
cdef MLCStr* mlc_str = NULL
574581
cdef PyAny any_ret
575-
cdef Str str_ret
582+
cdef object str_ret
576583
if type_index == kMLCNone:
577584
return None
578585
elif type_index == kMLCBool:
@@ -587,8 +594,10 @@ cdef inline object _any_c2py_inc_ref(MLCAny x):
587594
return str_c2py(x.v.v_str)
588595
elif type_index == kMLCStr:
589596
mlc_str = <MLCStr*>(x.v.v_obj)
597+
any_ret = PyAny()
598+
any_ret._mlc_any = x
590599
str_ret = Str.__new__(Str, str_c2py(mlc_str.data[:mlc_str.length]))
591-
str_ret._mlc_any = x
600+
str_ret._pyany = any_ret
592601
_check_error(_C_AnyIncRef(&x))
593602
return str_ret
594603
elif type_index == kMLCOpaque:
@@ -624,11 +633,11 @@ cdef inline PyAny _pyany_from_opaque(object x):
624633
args[0] = _MLCAnyPtr(<uint64_t>(<void*>x))
625634
args[1] = _MLCAnyPtr(<uint64_t>(<void*>_pyobj_deleter))
626635
args[2] = _MLCAnyRawStr(type_name)
627-
Py_INCREF(x)
636+
Py_IncRef(x)
628637
try:
629638
_func_call_impl_with_c_args(_OPAQUE_INIT, 3, args, &ret._mlc_any)
630639
except: # no-cython-lint
631-
Py_DECREF(x)
640+
Py_DecRef(x)
632641
raise
633642
return ret
634643

@@ -677,7 +686,7 @@ cdef inline MLCAny _any_py2c(object x, list temporary_storage):
677686
elif isinstance(x, PyAny):
678687
y = (<PyAny>x)._mlc_any
679688
elif isinstance(x, Str):
680-
y = (<Str>x)._mlc_any
689+
y = (<PyAny>(x._pyany))._mlc_any
681690
elif isinstance(x, bool):
682691
y = _MLCAnyBool(<bint>x)
683692
elif isinstance(x, Integral):
@@ -729,7 +738,7 @@ cdef inline MLCAny _any_py2c_dict(tuple x, list temporary_storage):
729738
cdef void _pyobj_deleter(void* handle) noexcept nogil:
730739
with gil:
731740
try:
732-
Py_DECREF(<object>(handle))
741+
Py_DecRef(<object>(handle))
733742
except Exception as exception:
734743
# TODO(@junrushao): Will need to handle exceptions more gracefully
735744
print(f"Error in _pyobj_deleter: {exception}")
@@ -767,7 +776,7 @@ cdef inline int32_t _func_safe_call_impl(
767776

768777
cdef inline PyAny _pyany_from_func(object py_func):
769778
cdef PyAny ret = PyAny()
770-
Py_INCREF(py_func)
779+
Py_IncRef(py_func)
771780
_check_error(_C_FuncCreate(<void*>(py_func), _pyobj_deleter, _func_safe_call, &ret._mlc_any))
772781
return ret
773782

0 commit comments

Comments
 (0)