Skip to content

Commit dd16acb

Browse files
authored
set test_torchscript = False for Blip2 testing (huggingface#35972)
* just skip * fix * fix * fix --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
1 parent 0a9923a commit dd16acb

File tree

1 file changed

+3
-224
lines changed

1 file changed

+3
-224
lines changed

tests/models/blip_2/test_modeling_blip_2.py

Lines changed: 3 additions & 224 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
"""Testing suite for the PyTorch BLIP-2 model."""
1616

1717
import inspect
18-
import os
1918
import tempfile
2019
import unittest
2120

@@ -36,7 +35,7 @@
3635
slow,
3736
torch_device,
3837
)
39-
from transformers.utils import is_torch_available, is_torch_sdpa_available, is_vision_available
38+
from transformers.utils import is_torch_available, is_vision_available
4039

4140
from ...generation.test_utils import GenerationTesterMixin
4241
from ...test_configuration_common import ConfigTester
@@ -477,7 +476,7 @@ class Blip2ForConditionalGenerationDecoderOnlyTest(ModelTesterMixin, GenerationT
477476
test_pruning = False
478477
test_resize_embeddings = False
479478
test_attention_outputs = False
480-
test_torchscript = True
479+
test_torchscript = False
481480
_is_composite = True
482481

483482
def setUp(self):
@@ -494,116 +493,6 @@ def test_for_conditional_generation(self):
494493
config_and_inputs = self.model_tester.prepare_config_and_inputs()
495494
self.model_tester.create_and_check_for_conditional_generation(*config_and_inputs)
496495

497-
def _create_and_check_torchscript(self, config, inputs_dict):
498-
# overwrite because BLIP requires input ids and pixel values as input
499-
if not self.test_torchscript:
500-
self.skipTest(reason="test_torchscript is set to `False`")
501-
502-
configs_no_init = _config_zero_init(config) # To be sure we have no Nan
503-
configs_no_init.torchscript = True
504-
for model_class in self.all_model_classes:
505-
for attn_implementation in ["eager", "sdpa"]:
506-
if attn_implementation == "sdpa" and (not model_class._supports_sdpa or not is_torch_sdpa_available()):
507-
continue
508-
509-
configs_no_init._attn_implementation = attn_implementation
510-
model = model_class(config=configs_no_init)
511-
model.to(torch_device)
512-
model.eval()
513-
inputs = self._prepare_for_class(inputs_dict, model_class)
514-
515-
main_input_name = model_class.main_input_name
516-
517-
try:
518-
if model.config.is_encoder_decoder:
519-
model.config.use_cache = False # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward
520-
main_input = inputs[main_input_name]
521-
input_ids = inputs["input_ids"]
522-
attention_mask = inputs["attention_mask"]
523-
decoder_input_ids = inputs["decoder_input_ids"]
524-
decoder_attention_mask = inputs["decoder_attention_mask"]
525-
model(main_input, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask)
526-
traced_model = torch.jit.trace(
527-
model, (main_input, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask)
528-
)
529-
else:
530-
main_input = inputs[main_input_name]
531-
input_ids = inputs["input_ids"]
532-
533-
if model.config._attn_implementation == "sdpa":
534-
trace_input = {main_input_name: main_input, "input_ids": input_ids}
535-
536-
if "attention_mask" in inputs:
537-
trace_input["attention_mask"] = inputs["attention_mask"]
538-
else:
539-
self.skipTest(reason="testing SDPA without attention_mask is not supported")
540-
541-
model(main_input, attention_mask=inputs["attention_mask"])
542-
# example_kwarg_inputs was introduced in torch==2.0, but it is fine here since SDPA has a requirement on torch>=2.1.
543-
traced_model = torch.jit.trace(model, example_kwarg_inputs=trace_input)
544-
else:
545-
model(main_input, input_ids)
546-
traced_model = torch.jit.trace(model, (main_input, input_ids))
547-
except RuntimeError:
548-
self.fail("Couldn't trace module.")
549-
550-
with tempfile.TemporaryDirectory() as tmp_dir_name:
551-
pt_file_name = os.path.join(tmp_dir_name, "traced_model.pt")
552-
553-
try:
554-
torch.jit.save(traced_model, pt_file_name)
555-
except Exception:
556-
self.fail("Couldn't save module.")
557-
558-
try:
559-
loaded_model = torch.jit.load(pt_file_name)
560-
except Exception:
561-
self.fail("Couldn't load module.")
562-
563-
model.to(torch_device)
564-
model.eval()
565-
566-
loaded_model.to(torch_device)
567-
loaded_model.eval()
568-
569-
model_state_dict = model.state_dict()
570-
loaded_model_state_dict = loaded_model.state_dict()
571-
572-
non_persistent_buffers = {}
573-
for key in loaded_model_state_dict.keys():
574-
if key not in model_state_dict.keys():
575-
non_persistent_buffers[key] = loaded_model_state_dict[key]
576-
577-
loaded_model_state_dict = {
578-
key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers
579-
}
580-
581-
self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))
582-
583-
model_buffers = list(model.buffers())
584-
for non_persistent_buffer in non_persistent_buffers.values():
585-
found_buffer = False
586-
for i, model_buffer in enumerate(model_buffers):
587-
if torch.equal(non_persistent_buffer, model_buffer):
588-
found_buffer = True
589-
break
590-
591-
self.assertTrue(found_buffer)
592-
model_buffers.pop(i)
593-
594-
models_equal = True
595-
for layer_name, p1 in model_state_dict.items():
596-
if layer_name in loaded_model_state_dict:
597-
p2 = loaded_model_state_dict[layer_name]
598-
if p1.data.ne(p2.data).sum() > 0:
599-
models_equal = False
600-
601-
self.assertTrue(models_equal)
602-
603-
# Avoid memory leak. Without this, each call increases RAM usage by ~20MB.
604-
# (Even with this call, there is still a memory leak of ~0.04MB)
605-
self.clear_torch_jit_class_registry()
606-
607496
@unittest.skip(reason="Hidden_states is tested in individual model tests")
608497
def test_hidden_states_output(self):
609498
pass
@@ -1015,7 +904,7 @@ class Blip2ModelTest(ModelTesterMixin, PipelineTesterMixin, GenerationTesterMixi
1015904
test_pruning = False
1016905
test_resize_embeddings = True
1017906
test_attention_outputs = False
1018-
test_torchscript = True
907+
test_torchscript = False
1019908
_is_composite = True
1020909

1021910
# TODO: Fix the failed tests
@@ -1049,116 +938,6 @@ def test_for_conditional_generation(self):
1049938
config_and_inputs = self.model_tester.prepare_config_and_inputs()
1050939
self.model_tester.create_and_check_for_conditional_generation(*config_and_inputs)
1051940

1052-
def _create_and_check_torchscript(self, config, inputs_dict):
1053-
# overwrite because BLIP requires input ids and pixel values as input
1054-
if not self.test_torchscript:
1055-
self.skipTest(reason="test_torchscript is set to `False`")
1056-
1057-
configs_no_init = _config_zero_init(config) # To be sure we have no Nan
1058-
configs_no_init.torchscript = True
1059-
for model_class in self.all_model_classes:
1060-
for attn_implementation in ["eager", "sdpa"]:
1061-
if attn_implementation == "sdpa" and (not model_class._supports_sdpa or not is_torch_sdpa_available()):
1062-
continue
1063-
1064-
configs_no_init._attn_implementation = attn_implementation
1065-
model = model_class(config=configs_no_init)
1066-
model.to(torch_device)
1067-
model.eval()
1068-
inputs = self._prepare_for_class(inputs_dict, model_class)
1069-
1070-
main_input_name = model_class.main_input_name
1071-
1072-
try:
1073-
if model.config.is_encoder_decoder:
1074-
model.config.use_cache = False # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward
1075-
main_input = inputs[main_input_name]
1076-
input_ids = inputs["input_ids"]
1077-
attention_mask = inputs["attention_mask"]
1078-
decoder_input_ids = inputs["decoder_input_ids"]
1079-
decoder_attention_mask = inputs["decoder_attention_mask"]
1080-
model(main_input, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask)
1081-
traced_model = torch.jit.trace(
1082-
model, (main_input, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask)
1083-
)
1084-
else:
1085-
main_input = inputs[main_input_name]
1086-
input_ids = inputs["input_ids"]
1087-
1088-
if model.config._attn_implementation == "sdpa":
1089-
trace_input = {main_input_name: main_input, "input_ids": input_ids}
1090-
1091-
if "attention_mask" in inputs:
1092-
trace_input["attention_mask"] = inputs["attention_mask"]
1093-
else:
1094-
self.skipTest(reason="testing SDPA without attention_mask is not supported")
1095-
1096-
model(main_input, attention_mask=inputs["attention_mask"])
1097-
# example_kwarg_inputs was introduced in torch==2.0, but it is fine here since SDPA has a requirement on torch>=2.1.
1098-
traced_model = torch.jit.trace(model, example_kwarg_inputs=trace_input)
1099-
else:
1100-
model(main_input, input_ids)
1101-
traced_model = torch.jit.trace(model, (main_input, input_ids))
1102-
except RuntimeError:
1103-
self.fail("Couldn't trace module.")
1104-
1105-
with tempfile.TemporaryDirectory() as tmp_dir_name:
1106-
pt_file_name = os.path.join(tmp_dir_name, "traced_model.pt")
1107-
1108-
try:
1109-
torch.jit.save(traced_model, pt_file_name)
1110-
except Exception:
1111-
self.fail("Couldn't save module.")
1112-
1113-
try:
1114-
loaded_model = torch.jit.load(pt_file_name)
1115-
except Exception:
1116-
self.fail("Couldn't load module.")
1117-
1118-
model.to(torch_device)
1119-
model.eval()
1120-
1121-
loaded_model.to(torch_device)
1122-
loaded_model.eval()
1123-
1124-
model_state_dict = model.state_dict()
1125-
loaded_model_state_dict = loaded_model.state_dict()
1126-
1127-
non_persistent_buffers = {}
1128-
for key in loaded_model_state_dict.keys():
1129-
if key not in model_state_dict.keys():
1130-
non_persistent_buffers[key] = loaded_model_state_dict[key]
1131-
1132-
loaded_model_state_dict = {
1133-
key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers
1134-
}
1135-
1136-
self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))
1137-
1138-
model_buffers = list(model.buffers())
1139-
for non_persistent_buffer in non_persistent_buffers.values():
1140-
found_buffer = False
1141-
for i, model_buffer in enumerate(model_buffers):
1142-
if torch.equal(non_persistent_buffer, model_buffer):
1143-
found_buffer = True
1144-
break
1145-
1146-
self.assertTrue(found_buffer)
1147-
model_buffers.pop(i)
1148-
1149-
models_equal = True
1150-
for layer_name, p1 in model_state_dict.items():
1151-
if layer_name in loaded_model_state_dict:
1152-
p2 = loaded_model_state_dict[layer_name]
1153-
if p1.data.ne(p2.data).sum() > 0:
1154-
models_equal = False
1155-
1156-
self.assertTrue(models_equal)
1157-
1158-
# Avoid memory leak. Without this, each call increases RAM usage by ~20MB.
1159-
# (Even with this call, there is still a memory leak of ~0.04MB)
1160-
self.clear_torch_jit_class_registry()
1161-
1162941
@unittest.skip(reason="Hidden_states is tested in individual model tests")
1163942
def test_hidden_states_output(self):
1164943
pass

0 commit comments

Comments
 (0)