Skip to content

Commit b3aeedc

Browse files
committed
bug fix for corner case: single frame dataset
1 parent e19efb0 commit b3aeedc

File tree

1 file changed

+9
-9
lines changed

1 file changed

+9
-9
lines changed

deepmd/utils/data.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -854,26 +854,26 @@ def _load_single_data(
854854
mmap_obj = self._get_memmap(path)
855855
# Slice the single frame and make an in-memory copy for modification
856856
if mmap_obj.ndim == 0:
857-
# Scalar array
858-
data = mmap_obj.copy().astype(dtype, copy=False)
857+
# case: single frame data && non-atomic
858+
data = mmap_obj.copy().astype(dtype, copy=False).reshape(1, -1)
859859
elif mmap_obj.ndim == 1:
860-
# Single-frame file (shape: [ndof]); only frame_idx==0 is valid
860+
# case: single frame data && atomic
861861
if frame_idx != 0:
862862
raise IndexError(
863863
f"frame index {frame_idx} out of range for single-frame file: {path}"
864864
)
865-
data = mmap_obj.copy().astype(dtype, copy=False)
865+
data = mmap_obj.copy().astype(dtype, copy=False).reshape(1, -1)
866866
else:
867-
# Regular [nframes, ...]
868-
data = mmap_obj[frame_idx].copy().astype(dtype, copy=False)
867+
# case: multi-frame data
868+
data = mmap_obj[frame_idx].copy().astype(dtype, copy=False).reshape(1, -1)
869869

870870
try:
871871
if vv["atomic"]:
872872
# Handle type_sel logic
873873
if vv["type_sel"] is not None:
874874
sel_mask = np.isin(self.atom_type, vv["type_sel"])
875875

876-
if mmap_obj.shape[1] == natoms_sel * ndof:
876+
if data.shape[1] == natoms_sel * ndof:
877877
if vv["output_natoms_for_type_sel"]:
878878
tmp = np.zeros([natoms, ndof], dtype=data.dtype)
879879
# sel_mask needs to be applied to the original atom layout
@@ -882,7 +882,7 @@ def _load_single_data(
882882
else: # output is natoms_sel
883883
natoms = natoms_sel
884884
idx_map = idx_map_sel
885-
elif mmap_obj.shape[1] == natoms * ndof:
885+
elif data.shape[1] == natoms * ndof:
886886
data = data.reshape([natoms, ndof])
887887
if vv["output_natoms_for_type_sel"]:
888888
pass
@@ -892,7 +892,7 @@ def _load_single_data(
892892
natoms = natoms_sel
893893
else: # Shape mismatch error
894894
raise ValueError(
895-
f"The shape of the data {key} in {set_dir} has width {mmap_obj.shape[1]}, which doesn't match either ({natoms_sel * ndof}) or ({natoms * ndof})"
895+
f"The shape of the data {key} in {set_dir} has width {data.shape[1]}, which doesn't match either ({natoms_sel * ndof}) or ({natoms * ndof})"
896896
)
897897

898898
# Handle special case for Hessian

0 commit comments

Comments
 (0)