2
2
from dataclasses import dataclass
3
3
from functools import cached_property
4
4
from mmap import ACCESS_READ , mmap
5
- from typing import IO , override
5
+ from operator import itemgetter
6
+ from typing import IO , final , override
6
7
7
8
import more_itertools as mit
8
9
import numpy .typing as npt
9
10
import torch
10
- from jaxtyping import Shaped
11
11
from mcap .data_stream import ReadDataStream
12
12
from mcap .decoder import DecoderFactory
13
13
from mcap .opcode import Opcode
@@ -32,6 +32,7 @@ class MessageIndex:
32
32
message_length : int
33
33
34
34
35
+ @final
35
36
class McapTensorSource (TensorSource ):
36
37
@validate_call (config = BaseModel .model_config )
37
38
def __init__ (
@@ -47,8 +48,8 @@ def __init__(
47
48
with bound_contextvars (
48
49
path = path .as_posix (), topic = topic , message_decoder_factory = decoder_factory
49
50
):
50
- self ._path : FilePath = path
51
- self ._validate_crcs : bool = validate_crcs
51
+ self ._path = path
52
+ self ._validate_crcs = validate_crcs
52
53
53
54
summary = SeekingReader (
54
55
stream = self ._file , validate_crcs = self ._validate_crcs
@@ -73,13 +74,14 @@ def __init__(
73
74
logger .error (msg := "missing message decoder" )
74
75
raise RuntimeError (msg )
75
76
76
- self ._message_decoder : Callable [[ bytes ], object ] = message_decoder
77
- self ._chunk_indexes : tuple [ ChunkIndex , ...] = tuple (
77
+ self ._message_decoder = message_decoder
78
+ self ._chunk_indexes = tuple (
78
79
chunk_index
79
80
for chunk_index in summary .chunk_indexes
80
81
if self ._channel .id in chunk_index .message_index_offsets
81
82
)
82
- self ._decoder : Callable [[bytes ], npt .ArrayLike ] = decoder
83
+ self ._decoder = decoder
84
+ self ._mmap = None
83
85
84
86
@property
85
87
def _file (self ) -> IO [bytes ]:
@@ -89,42 +91,55 @@ def _file(self) -> IO[bytes]:
89
91
90
92
case None | mmap (closed = True ):
91
93
with self ._path .open ("rb" ) as f :
92
- self ._mmap : mmap = mmap (
93
- fileno = f .fileno (), length = 0 , access = ACCESS_READ
94
- )
94
+ self ._mmap = mmap (fileno = f .fileno (), length = 0 , access = ACCESS_READ )
95
95
96
96
case _:
97
97
raise RuntimeError
98
98
99
99
return self ._mmap # pyright: ignore[reportReturnType]
100
100
101
101
@override
102
- def __getitem__ (self , indexes : Iterable [int ]) -> Shaped [Tensor , "b h w c" ]:
103
- frames : Mapping [int , npt .ArrayLike ] = {}
104
-
105
- message_indexes_by_chunk_start_offset : Mapping [
106
- int , Iterable [tuple [int , MessageIndex ]]
107
- ] = mit .map_reduce (
108
- zip (indexes , (self ._message_indexes [idx ] for idx in indexes ), strict = True ),
109
- keyfunc = lambda x : x [1 ].chunk_start_offset ,
110
- )
102
+ def __getitem__ (self , indexes : int | Iterable [int ]) -> Tensor :
103
+ match indexes :
104
+ case Iterable ():
105
+ arrays : Mapping [int , npt .ArrayLike ] = {}
106
+ message_indexes = (self ._message_indexes [idx ] for idx in indexes )
107
+ indexes_by_chunk_start_offset = mit .map_reduce (
108
+ zip (indexes , message_indexes , strict = True ),
109
+ keyfunc = lambda x : x [1 ].chunk_start_offset ,
110
+ )
111
+
112
+ for chunk_start_offset , chunk_indexes in sorted (
113
+ indexes_by_chunk_start_offset .items (), key = itemgetter (0 )
114
+ ):
115
+ _ = self ._file .seek (chunk_start_offset + 1 + 8 )
116
+ chunk = Chunk .read (ReadDataStream (self ._file ))
117
+ stream , _ = get_chunk_data_stream (
118
+ chunk , validate_crc = self ._validate_crcs
119
+ )
120
+ for index , message_index in sorted (
121
+ chunk_indexes , key = lambda x : x [1 ].message_start_offset
122
+ ):
123
+ stream .read (message_index .message_start_offset - stream .count ) # pyright: ignore[reportUnusedCallResult]
124
+ message = Message .read (stream , message_index .message_length )
125
+ decoded_message = self ._message_decoder (message .data )
126
+ arrays [index ] = self ._decoder (decoded_message .data )
111
127
112
- for (
113
- chunk_start_offset ,
114
- chunk_message_indexes ,
115
- ) in message_indexes_by_chunk_start_offset .items ():
116
- self ._file .seek (chunk_start_offset + 1 + 8 ) # pyright: ignore[reportUnusedCallResult]
117
- chunk = Chunk .read (ReadDataStream (self ._file ))
118
- stream , _ = get_chunk_data_stream (chunk , validate_crc = self ._validate_crcs )
119
- for frame_index , message_index in sorted (
120
- chunk_message_indexes , key = lambda x : x [1 ].message_start_offset
121
- ):
122
- stream .read (message_index .message_start_offset - stream .count ) # pyright: ignore[reportUnusedCallResult]
123
- message = Message .read (stream , message_index .message_length )
128
+ tensors = [torch .from_numpy (arrays [idx ]) for idx in indexes ] # pyright: ignore[reportUnknownMemberType]
129
+
130
+ return torch .stack (tensors )
131
+
132
+ case _:
133
+ message_index = self ._message_indexes [indexes ]
134
+ _ = self ._file .seek (message_index .chunk_start_offset + 1 + 8 )
135
+ chunk = Chunk .read (ReadDataStream (self ._file ))
136
+ stream , _ = get_chunk_data_stream (chunk , self ._validate_crcs )
137
+ _ = stream .read (message_index .message_start_offset - stream .count )
138
+ message = Message .read (stream , length = message_index .message_length )
124
139
decoded_message = self ._message_decoder (message .data )
125
- frames [ frame_index ] = self ._decoder (decoded_message .data ) # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType, reportAttributeAccessIssue]
140
+ array = self ._decoder (decoded_message .data )
126
141
127
- return torch .stack ([ torch . from_numpy (frames [ idx ]) for idx in indexes ] ) # pyright: ignore[reportUnknownMemberType]
142
+ return torch .from_numpy (array ) # pyright: ignore[reportUnknownMemberType]
128
143
129
144
@override
130
145
def __len__ (self ) -> int :
0 commit comments