Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions python/tvm_ffi/cython/string.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ class String(str, PyNativeObject):
val._tvm_ffi_cached_object = None
return val

def __reduce_ex__(self, protocol):
return (str, (str(self),))

# pylint: disable=no-self-argument
def __from_tvm_ffi_object__(cls, obj: Any) -> "String":
"""Construct a ``String`` from an FFI object (internal)."""
Expand All @@ -80,6 +83,9 @@ class Bytes(bytes, PyNativeObject):
val._tvm_ffi_cached_object = None
return val

def __reduce_ex__(self, protocol):
return (bytes, (bytes(self),))

# pylint: disable=no-self-argument
def __from_tvm_ffi_object__(cls, obj: Any) -> "Bytes":
"""Construct ``Bytes`` from an FFI object (internal)."""
Expand Down
20 changes: 19 additions & 1 deletion tests/python/test_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import pickle

import tvm_ffi
import tvm_ffi.testing


def test_string() -> None:
Expand All @@ -33,6 +34,15 @@ def test_string() -> None:

s4 = pickle.loads(pickle.dumps(s))
assert s4 == "hello"
assert type(s4) is str

cached = fecho("x" * 200)
assert isinstance(cached, tvm_ffi.core.String)
assert cached._tvm_ffi_cached_object is not None

cached_roundtrip = pickle.loads(pickle.dumps(cached))
assert cached_roundtrip == cached
assert type(cached_roundtrip) is str


def test_bytes() -> None:
Expand All @@ -52,7 +62,15 @@ def test_bytes() -> None:

b5 = pickle.loads(pickle.dumps(b))
assert b5 == b"hello"
assert isinstance(b5, tvm_ffi.core.Bytes)
assert type(b5) is bytes

cached = fecho(b"x" * 200)
assert isinstance(cached, tvm_ffi.core.Bytes)
assert cached._tvm_ffi_cached_object is not None

cached_roundtrip = pickle.loads(pickle.dumps(cached))
assert cached_roundtrip == cached
assert type(cached_roundtrip) is bytes


def test_string_find_substr() -> None:
Expand Down
Loading