1515"""Testing suite for the PyTorch BLIP-2 model."""
1616
1717import inspect
18- import os
1918import tempfile
2019import unittest
2120
3635 slow ,
3736 torch_device ,
3837)
39- from transformers .utils import is_torch_available , is_torch_sdpa_available , is_vision_available
38+ from transformers .utils import is_torch_available , is_vision_available
4039
4140from ...generation .test_utils import GenerationTesterMixin
4241from ...test_configuration_common import ConfigTester
@@ -477,7 +476,7 @@ class Blip2ForConditionalGenerationDecoderOnlyTest(ModelTesterMixin, GenerationT
477476 test_pruning = False
478477 test_resize_embeddings = False
479478 test_attention_outputs = False
480- test_torchscript = True
479+ test_torchscript = False
481480 _is_composite = True
482481
483482 def setUp (self ):
@@ -494,116 +493,6 @@ def test_for_conditional_generation(self):
494493 config_and_inputs = self .model_tester .prepare_config_and_inputs ()
495494 self .model_tester .create_and_check_for_conditional_generation (* config_and_inputs )
496495
    def _create_and_check_torchscript(self, config, inputs_dict):
        """Trace every model class with torch.jit, round-trip it through save/load,
        and verify the reloaded weights and buffers match the original model.

        Overwritten from the common mixin because BLIP-2 requires both input ids
        and pixel values as inputs to the forward pass, so the generic tracing
        inputs do not apply.

        Args:
            config: model configuration to instantiate each class under test with.
            inputs_dict: raw inputs; filtered per model class via `_prepare_for_class`.
        """
        if not self.test_torchscript:
            self.skipTest(reason="test_torchscript is set to `False`")

        configs_no_init = _config_zero_init(config)  # To be sure we have no Nan
        # torchscript=True makes the model return tuples instead of ModelOutput
        # objects, which is required for jit tracing.
        configs_no_init.torchscript = True
        for model_class in self.all_model_classes:
            # Trace under both attention implementations; SDPA is skipped when
            # the class or the installed torch does not support it.
            for attn_implementation in ["eager", "sdpa"]:
                if attn_implementation == "sdpa" and (not model_class._supports_sdpa or not is_torch_sdpa_available()):
                    continue

                configs_no_init._attn_implementation = attn_implementation
                model = model_class(config=configs_no_init)
                model.to(torch_device)
                model.eval()
                inputs = self._prepare_for_class(inputs_dict, model_class)

                main_input_name = model_class.main_input_name

                try:
                    if model.config.is_encoder_decoder:
                        model.config.use_cache = False  # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward
                        main_input = inputs[main_input_name]
                        input_ids = inputs["input_ids"]
                        attention_mask = inputs["attention_mask"]
                        decoder_input_ids = inputs["decoder_input_ids"]
                        decoder_attention_mask = inputs["decoder_attention_mask"]
                        # Run eagerly once before tracing so failures surface
                        # with a readable stack trace rather than a jit error.
                        model(main_input, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask)
                        traced_model = torch.jit.trace(
                            model, (main_input, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask)
                        )
                    else:
                        main_input = inputs[main_input_name]
                        input_ids = inputs["input_ids"]

                        if model.config._attn_implementation == "sdpa":
                            # SDPA tracing needs keyword inputs so the attention
                            # mask is routed to the right argument.
                            trace_input = {main_input_name: main_input, "input_ids": input_ids}

                            if "attention_mask" in inputs:
                                trace_input["attention_mask"] = inputs["attention_mask"]
                            else:
                                self.skipTest(reason="testing SDPA without attention_mask is not supported")

                            model(main_input, attention_mask=inputs["attention_mask"])
                            # example_kwarg_inputs was introduced in torch==2.0, but it is fine here since SDPA has a requirement on torch>=2.1.
                            traced_model = torch.jit.trace(model, example_kwarg_inputs=trace_input)
                        else:
                            model(main_input, input_ids)
                            traced_model = torch.jit.trace(model, (main_input, input_ids))
                except RuntimeError:
                    self.fail("Couldn't trace module.")

                # Round-trip the traced module through disk to exercise
                # torch.jit serialization.
                with tempfile.TemporaryDirectory() as tmp_dir_name:
                    pt_file_name = os.path.join(tmp_dir_name, "traced_model.pt")

                    try:
                        torch.jit.save(traced_model, pt_file_name)
                    except Exception:
                        self.fail("Couldn't save module.")

                    try:
                        loaded_model = torch.jit.load(pt_file_name)
                    except Exception:
                        self.fail("Couldn't load module.")

                    model.to(torch_device)
                    model.eval()

                    loaded_model.to(torch_device)
                    loaded_model.eval()

                    model_state_dict = model.state_dict()
                    loaded_model_state_dict = loaded_model.state_dict()

                    # jit.load surfaces non-persistent buffers in its state dict
                    # while the eager model does not; collect them so the key
                    # sets can be compared on persistent entries only.
                    non_persistent_buffers = {}
                    for key in loaded_model_state_dict.keys():
                        if key not in model_state_dict.keys():
                            non_persistent_buffers[key] = loaded_model_state_dict[key]

                    loaded_model_state_dict = {
                        key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers
                    }

                    self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))

                    # Each non-persistent buffer from the loaded model must match
                    # (by value) exactly one buffer of the eager model; matched
                    # buffers are removed so duplicates are not double-counted.
                    # NOTE(review): if `model_buffers` were empty while
                    # `non_persistent_buffers` is not, `i` would be unbound at
                    # the `pop` below — the assertTrue fails first in practice.
                    model_buffers = list(model.buffers())
                    for non_persistent_buffer in non_persistent_buffers.values():
                        found_buffer = False
                        for i, model_buffer in enumerate(model_buffers):
                            if torch.equal(non_persistent_buffer, model_buffer):
                                found_buffer = True
                                break

                        self.assertTrue(found_buffer)
                        model_buffers.pop(i)

                    # Element-wise comparison of every shared parameter tensor.
                    models_equal = True
                    for layer_name, p1 in model_state_dict.items():
                        if layer_name in loaded_model_state_dict:
                            p2 = loaded_model_state_dict[layer_name]
                            if p1.data.ne(p2.data).sum() > 0:
                                models_equal = False

                    self.assertTrue(models_equal)

                    # Avoid memory leak. Without this, each call increase RAM usage by ~20MB.
                    # (Even with this call, there are still memory leak by ~0.04MB)
                    self.clear_torch_jit_class_registry()
606-
    @unittest.skip(reason="Hidden_states is tested in individual model tests")
    def test_hidden_states_output(self):
        # Intentionally empty: hidden-state outputs for the composite BLIP-2
        # model are covered by the sub-model (vision/Q-Former/LM) test suites.
        pass
@@ -1015,7 +904,7 @@ class Blip2ModelTest(ModelTesterMixin, PipelineTesterMixin, GenerationTesterMixi
1015904 test_pruning = False
1016905 test_resize_embeddings = True
1017906 test_attention_outputs = False
1018- test_torchscript = True
907+ test_torchscript = False
1019908 _is_composite = True
1020909
1021910 # TODO: Fix the failed tests
@@ -1049,116 +938,6 @@ def test_for_conditional_generation(self):
1049938 config_and_inputs = self .model_tester .prepare_config_and_inputs ()
1050939 self .model_tester .create_and_check_for_conditional_generation (* config_and_inputs )
1051940
    def _create_and_check_torchscript(self, config, inputs_dict):
        """Trace every model class with torch.jit, round-trip it through save/load,
        and verify the reloaded weights and buffers match the original model.

        Overwritten from the common mixin because BLIP-2 requires both input ids
        and pixel values as inputs to the forward pass, so the generic tracing
        inputs do not apply. Duplicate of the override in the decoder-only test
        class above.

        Args:
            config: model configuration to instantiate each class under test with.
            inputs_dict: raw inputs; filtered per model class via `_prepare_for_class`.
        """
        if not self.test_torchscript:
            self.skipTest(reason="test_torchscript is set to `False`")

        configs_no_init = _config_zero_init(config)  # To be sure we have no Nan
        # torchscript=True makes the model return tuples instead of ModelOutput
        # objects, which is required for jit tracing.
        configs_no_init.torchscript = True
        for model_class in self.all_model_classes:
            # Trace under both attention implementations; SDPA is skipped when
            # the class or the installed torch does not support it.
            for attn_implementation in ["eager", "sdpa"]:
                if attn_implementation == "sdpa" and (not model_class._supports_sdpa or not is_torch_sdpa_available()):
                    continue

                configs_no_init._attn_implementation = attn_implementation
                model = model_class(config=configs_no_init)
                model.to(torch_device)
                model.eval()
                inputs = self._prepare_for_class(inputs_dict, model_class)

                main_input_name = model_class.main_input_name

                try:
                    if model.config.is_encoder_decoder:
                        model.config.use_cache = False  # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward
                        main_input = inputs[main_input_name]
                        input_ids = inputs["input_ids"]
                        attention_mask = inputs["attention_mask"]
                        decoder_input_ids = inputs["decoder_input_ids"]
                        decoder_attention_mask = inputs["decoder_attention_mask"]
                        # Run eagerly once before tracing so failures surface
                        # with a readable stack trace rather than a jit error.
                        model(main_input, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask)
                        traced_model = torch.jit.trace(
                            model, (main_input, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask)
                        )
                    else:
                        main_input = inputs[main_input_name]
                        input_ids = inputs["input_ids"]

                        if model.config._attn_implementation == "sdpa":
                            # SDPA tracing needs keyword inputs so the attention
                            # mask is routed to the right argument.
                            trace_input = {main_input_name: main_input, "input_ids": input_ids}

                            if "attention_mask" in inputs:
                                trace_input["attention_mask"] = inputs["attention_mask"]
                            else:
                                self.skipTest(reason="testing SDPA without attention_mask is not supported")

                            model(main_input, attention_mask=inputs["attention_mask"])
                            # example_kwarg_inputs was introduced in torch==2.0, but it is fine here since SDPA has a requirement on torch>=2.1.
                            traced_model = torch.jit.trace(model, example_kwarg_inputs=trace_input)
                        else:
                            model(main_input, input_ids)
                            traced_model = torch.jit.trace(model, (main_input, input_ids))
                except RuntimeError:
                    self.fail("Couldn't trace module.")

                # Round-trip the traced module through disk to exercise
                # torch.jit serialization.
                with tempfile.TemporaryDirectory() as tmp_dir_name:
                    pt_file_name = os.path.join(tmp_dir_name, "traced_model.pt")

                    try:
                        torch.jit.save(traced_model, pt_file_name)
                    except Exception:
                        self.fail("Couldn't save module.")

                    try:
                        loaded_model = torch.jit.load(pt_file_name)
                    except Exception:
                        self.fail("Couldn't load module.")

                    model.to(torch_device)
                    model.eval()

                    loaded_model.to(torch_device)
                    loaded_model.eval()

                    model_state_dict = model.state_dict()
                    loaded_model_state_dict = loaded_model.state_dict()

                    # jit.load surfaces non-persistent buffers in its state dict
                    # while the eager model does not; collect them so the key
                    # sets can be compared on persistent entries only.
                    non_persistent_buffers = {}
                    for key in loaded_model_state_dict.keys():
                        if key not in model_state_dict.keys():
                            non_persistent_buffers[key] = loaded_model_state_dict[key]

                    loaded_model_state_dict = {
                        key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers
                    }

                    self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))

                    # Each non-persistent buffer from the loaded model must match
                    # (by value) exactly one buffer of the eager model; matched
                    # buffers are removed so duplicates are not double-counted.
                    # NOTE(review): if `model_buffers` were empty while
                    # `non_persistent_buffers` is not, `i` would be unbound at
                    # the `pop` below — the assertTrue fails first in practice.
                    model_buffers = list(model.buffers())
                    for non_persistent_buffer in non_persistent_buffers.values():
                        found_buffer = False
                        for i, model_buffer in enumerate(model_buffers):
                            if torch.equal(non_persistent_buffer, model_buffer):
                                found_buffer = True
                                break

                        self.assertTrue(found_buffer)
                        model_buffers.pop(i)

                    # Element-wise comparison of every shared parameter tensor.
                    models_equal = True
                    for layer_name, p1 in model_state_dict.items():
                        if layer_name in loaded_model_state_dict:
                            p2 = loaded_model_state_dict[layer_name]
                            if p1.data.ne(p2.data).sum() > 0:
                                models_equal = False

                    self.assertTrue(models_equal)

                    # Avoid memory leak. Without this, each call increase RAM usage by ~20MB.
                    # (Even with this call, there are still memory leak by ~0.04MB)
                    self.clear_torch_jit_class_registry()
1161-
    @unittest.skip(reason="Hidden_states is tested in individual model tests")
    def test_hidden_states_output(self):
        # Intentionally empty: hidden-state outputs for the composite BLIP-2
        # model are covered by the sub-model (vision/Q-Former/LM) test suites.
        pass
0 commit comments