@@ -136,28 +136,31 @@ def test_resnet18_torch_exec_ops(ir):
136
136
not importlib .util .find_spec ("torchvision" ),
137
137
"torchvision is not installed" ,
138
138
)
139
- def test_mobilenet_v2 (ir ):
140
- model = models .mobilenet_v2 (pretrained = True ).eval ().to ("cuda" )
141
- input = torch .randn ((1 , 3 , 224 , 224 )).to ("cuda" )
139
+ @pytest .mark .parametrize ("dtype" , [torch .float16 , torch .bfloat16 , torch .float32 ])
140
+ def test_mobilenet_v2 (ir , dtype ):
141
+ model = models .mobilenet_v2 (pretrained = True ).eval ().to ("cuda" ).to (dtype )
142
+ input = torch .randn ((1 , 3 , 224 , 224 )).to ("cuda" ).to (dtype )
142
143
143
144
compile_spec = {
144
145
"inputs" : [
145
- torchtrt .Input (
146
- input .shape , dtype = torch .float , format = torch .contiguous_format
147
- )
146
+ torchtrt .Input (input .shape , dtype = dtype , format = torch .contiguous_format )
148
147
],
149
148
"device" : torchtrt .Device ("cuda:0" ),
150
- "enabled_precisions" : {torch .float },
151
149
"ir" : ir ,
152
150
"pass_through_build_failures" : True ,
153
151
"optimization_level" : 1 ,
154
152
"min_block_size" : 10 ,
155
153
"cache_built_engines" : False ,
156
154
"reuse_cached_engines" : False ,
155
+ "use_explicit_typing" : True ,
157
156
}
158
157
159
158
trt_mod = torchtrt .compile (model , ** compile_spec )
160
- cos_sim = cosine_similarity (model (input ), trt_mod (input ))
159
+ pyt_output = model (input )
160
+ trt_output = trt_mod (input )
161
+ assert pyt_output .dtype == trt_output .dtype
162
+ assert pyt_output .dtype == dtype
163
+ cos_sim = cosine_similarity (pyt_output , trt_output )
161
164
assertions .assertTrue (
162
165
cos_sim > COSINE_THRESHOLD ,
163
166
msg = f"Mobilenet v2 TRT outputs don't match with the original model. Cosine sim score: { cos_sim } Threshold: { COSINE_THRESHOLD } " ,
@@ -172,28 +175,36 @@ def test_mobilenet_v2(ir):
172
175
not importlib .util .find_spec ("timm" ) or not importlib .util .find_spec ("torchvision" ),
173
176
"timm or torchvision not installed" ,
174
177
)
175
- def test_efficientnet_b0 (ir ):
176
- model = timm .create_model ("efficientnet_b0" , pretrained = True ).eval ().to ("cuda" )
177
- input = torch .randn ((1 , 3 , 224 , 224 )).to ("cuda" )
178
+ @pytest .mark .parametrize ("dtype" , [torch .float16 , torch .bfloat16 , torch .float32 ])
179
+ def test_efficientnet_b0 (ir , dtype ):
180
+ model = (
181
+ timm .create_model ("efficientnet_b0" , pretrained = True )
182
+ .eval ()
183
+ .to ("cuda" )
184
+ .to (dtype )
185
+ )
186
+ input = torch .randn ((1 , 3 , 224 , 224 )).to ("cuda" ).to (dtype )
178
187
179
188
compile_spec = {
180
189
"inputs" : [
181
- torchtrt .Input (
182
- input .shape , dtype = torch .float , format = torch .contiguous_format
183
- )
190
+ torchtrt .Input (input .shape , dtype = dtype , format = torch .contiguous_format )
184
191
],
185
192
"device" : torchtrt .Device ("cuda:0" ),
186
- "enabled_precisions" : {torch .float },
187
193
"ir" : ir ,
188
194
"pass_through_build_failures" : True ,
189
195
"optimization_level" : 1 ,
190
196
"min_block_size" : 10 ,
191
197
"cache_built_engines" : False ,
192
198
"reuse_cached_engines" : False ,
199
+ "use_explicit_typing" : True ,
193
200
}
194
201
195
202
trt_mod = torchtrt .compile (model , ** compile_spec )
196
- cos_sim = cosine_similarity (model (input ), trt_mod (input ))
203
+ pyt_output = model (input )
204
+ trt_output = trt_mod (input )
205
+ assert pyt_output .dtype == trt_output .dtype
206
+ assert pyt_output .dtype == dtype
207
+ cos_sim = cosine_similarity (pyt_output , trt_output )
197
208
assertions .assertTrue (
198
209
cos_sim > COSINE_THRESHOLD ,
199
210
msg = f"EfficientNet-B0 TRT outputs don't match with the original model. Cosine sim score: { cos_sim } Threshold: { COSINE_THRESHOLD } " ,
@@ -208,10 +219,11 @@ def test_efficientnet_b0(ir):
208
219
not importlib .util .find_spec ("transformers" ),
209
220
"transformers is required to run this test" ,
210
221
)
211
- def test_bert_base_uncased (ir ):
222
+ @pytest .mark .parametrize ("dtype" , [torch .float16 , torch .bfloat16 , torch .float32 ])
223
+ def test_bert_base_uncased (ir , dtype ):
212
224
from transformers import BertModel
213
225
214
- model = BertModel .from_pretrained ("bert-base-uncased" ).cuda ().eval ()
226
+ model = BertModel .from_pretrained ("bert-base-uncased" ).cuda ().eval (). to ( dtype )
215
227
input = torch .randint (0 , 2 , (1 , 14 ), dtype = torch .int32 ).to ("cuda" )
216
228
input2 = torch .randint (0 , 2 , (1 , 14 ), dtype = torch .int32 ).to ("cuda" )
217
229
@@ -229,21 +241,23 @@ def test_bert_base_uncased(ir):
229
241
),
230
242
],
231
243
"device" : torchtrt .Device ("cuda:0" ),
232
- "enabled_precisions" : {torch .float },
233
244
"truncate_double" : True ,
234
245
"ir" : ir ,
235
246
"pass_through_build_failures" : True ,
236
247
"optimization_level" : 1 ,
237
248
"min_block_size" : 15 ,
238
249
"cache_built_engines" : False ,
239
250
"reuse_cached_engines" : False ,
251
+ "use_explicit_typing" : True ,
240
252
}
241
253
trt_mod = torchtrt .compile (model , ** compile_spec )
242
254
243
255
model_outputs = model (input , input2 )
244
256
trt_model_outputs = trt_mod (input , input2 )
245
257
for key in model_outputs .keys ():
246
258
out , trt_out = model_outputs [key ], trt_model_outputs [key ]
259
+ assert out .dtype == trt_out .dtype
260
+ assert out .dtype == dtype
247
261
cos_sim = cosine_similarity (out , trt_out )
248
262
assertions .assertTrue (
249
263
cos_sim > COSINE_THRESHOLD ,
0 commit comments