|
2 | 2 | import os |
3 | 3 | import zlib |
4 | 4 | from pathlib import Path |
5 | | -from typing import Any, Dict, List, Optional, Union |
| 5 | +from typing import Any, Union |
6 | 6 |
|
7 | 7 | import ase.db.core |
8 | 8 | import ase.db.row |
|
11 | 11 | from ase.calculators.singlepoint import SinglePointCalculator |
12 | 12 | import lmdb |
13 | 13 | import numpy as np |
| 14 | +import orjson |
14 | 15 |
|
15 | 16 | logger = logging.getLogger(__name__) |
16 | 17 |
|
@@ -233,17 +234,34 @@ def get_atoms(self, idx: int) -> Atoms: |
def encode_object(obj: Any, compress=True, json_encode=True) -> bytes:
    """Serialize *obj* to bytes, optionally JSON-encoding and compressing it.

    Args:
        obj: The value to store. With ``json_encode=True`` it is whatever
            ``orjson`` or ASE's ``encode`` can serialize; with
            ``json_encode=False`` it must already be ``str`` or bytes-like.
        compress: If true, zlib-compress the serialized payload.
        json_encode: If true, serialize ``obj`` to JSON first.

    Returns:
        The serialized (and, if requested, compressed) payload.

    Raises:
        TypeError: If ``json_encode`` is false and ``obj`` is neither ``str``
            nor bytes-like.
    """
    if json_encode:
        try:
            # Fast path. OPT_SERIALIZE_NUMPY lets orjson write numpy arrays
            # directly as plain JSON arrays (no ASE ``__ndarray__`` tag).
            # NOTE(review): arrays written this way are read back by
            # ``orjson.loads`` as lists, and orjson emits NaN/Infinity as
            # ``null`` -- confirm that readers tolerate both.
            payload = orjson.dumps(obj, option=orjson.OPT_SERIALIZE_NUMPY)
        except orjson.JSONEncodeError:
            # Anything orjson rejects goes through ASE's own JSON encoder.
            payload = encode(obj).encode("utf-8")
    elif isinstance(obj, str):
        payload = obj.encode("utf-8")
    elif isinstance(obj, (bytes, bytearray, memoryview)):
        payload = bytes(obj)
    else:
        # A bare ``bytes(obj)`` would turn an int into that many NUL bytes and
        # silently store garbage, so refuse anything that is not raw data.
        raise TypeError(
            "encode_object(json_encode=False) expects str or bytes-like, "
            f"got {type(obj).__name__}"
        )

    return zlib.compress(payload) if compress else payload
240 | 249 |
|
def decode_bytestream(bytestream: bytes, decompress=True, json_decode=True) -> Any:
    """Decode a payload written by ``encode_object``.

    Args:
        bytestream: The stored payload.
        decompress: If true, zlib-decompress ``bytestream`` first.
        json_decode: If true, parse the payload as JSON; otherwise return it
            as UTF-8 text.

    Returns:
        The decoded Python object when ``json_decode`` is true, else a ``str``.
    """
    if decompress:
        bytestream = zlib.decompress(bytestream)

    if not json_decode:
        return bytestream.decode("utf-8")

    # ASE's JSON encoder tags values that plain JSON cannot express with
    # dunder-style keys (``__ndarray__``, ``__complex__``, ``__datetime__``,
    # ...). Only ASE's decoder rebuilds those objects, so take the orjson fast
    # path only when no dunder-prefixed JSON string appears in the payload.
    # A false positive merely costs the slower, still-correct decoder.
    # NOTE(review): the two decoders may not return identical Python types for
    # the same JSON (ASE post-processes decoded dicts) -- confirm callers
    # accept either.
    if b'"__' not in bytestream:
        try:
            return orjson.loads(bytestream)
        except orjson.JSONDecodeError:
            # Deliberate fall-through: let ASE's decoder try the payload.
            pass

    return decode(bytestream.decode("utf-8"))
0 commit comments