diff --git a/ai_edge_torch/_convert/test/test_convert.py b/ai_edge_torch/_convert/test/test_convert.py index ebb0d8597..70523d47f 100644 --- a/ai_edge_torch/_convert/test/test_convert.py +++ b/ai_edge_torch/_convert/test/test_convert.py @@ -100,6 +100,15 @@ def test_convert_resnet18(self): model_coverage.compare_tflite_torch(edge_model, torch_module, args) ) + def test_convert_mobilenet_v2(self): + args = (torch.randn(4, 3, 224, 224),) + torch_module = torchvision.models.mobilenet_v2().eval() + edge_model = ai_edge_torch.convert(torch_module, args) + + self.assertTrue( + model_coverage.compare_tflite_torch(edge_model, torch_module, args) + ) + def test_signature_args_ordering(self): """Tests conversion of a model with more than 10 arguments."""