
set test_torchscript = False for Blip2 testing #35972


Merged: 7 commits (Feb 14, 2025)
227 changes: 3 additions & 224 deletions tests/models/blip_2/test_modeling_blip_2.py
@@ -15,7 +15,6 @@
"""Testing suite for the PyTorch BLIP-2 model."""

import inspect
import os
import tempfile
import unittest

@@ -36,7 +35,7 @@
slow,
torch_device,
)
from transformers.utils import is_torch_available, is_torch_sdpa_available, is_vision_available
from transformers.utils import is_torch_available, is_vision_available

from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
@@ -478,7 +477,7 @@ class Blip2ForConditionalGenerationDecoderOnlyTest(ModelTesterMixin, GenerationT
test_pruning = False
test_resize_embeddings = False
test_attention_outputs = False
test_torchscript = True
test_torchscript = False
_is_composite = True

def setUp(self):
@@ -495,116 +494,6 @@ def test_for_conditional_generation(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_conditional_generation(*config_and_inputs)

def _create_and_check_torchscript(self, config, inputs_dict):
# overwrite because BLIP requires input ids and pixel values as input
if not self.test_torchscript:
self.skipTest(reason="test_torchscript is set to `False`")

configs_no_init = _config_zero_init(config) # To be sure we have no Nan
configs_no_init.torchscript = True
for model_class in self.all_model_classes:
for attn_implementation in ["eager", "sdpa"]:
if attn_implementation == "sdpa" and (not model_class._supports_sdpa or not is_torch_sdpa_available()):
continue

configs_no_init._attn_implementation = attn_implementation
model = model_class(config=configs_no_init)
model.to(torch_device)
model.eval()
inputs = self._prepare_for_class(inputs_dict, model_class)

main_input_name = model_class.main_input_name

try:
if model.config.is_encoder_decoder:
model.config.use_cache = False # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward
main_input = inputs[main_input_name]
input_ids = inputs["input_ids"]
attention_mask = inputs["attention_mask"]
decoder_input_ids = inputs["decoder_input_ids"]
decoder_attention_mask = inputs["decoder_attention_mask"]
model(main_input, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask)
traced_model = torch.jit.trace(
model, (main_input, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask)
)
else:
main_input = inputs[main_input_name]
input_ids = inputs["input_ids"]

if model.config._attn_implementation == "sdpa":
trace_input = {main_input_name: main_input, "input_ids": input_ids}

if "attention_mask" in inputs:
trace_input["attention_mask"] = inputs["attention_mask"]
else:
self.skipTest(reason="testing SDPA without attention_mask is not supported")

model(main_input, attention_mask=inputs["attention_mask"])
# example_kwarg_inputs was introduced in torch==2.0, but it is fine here since SDPA has a requirement on torch>=2.1.
traced_model = torch.jit.trace(model, example_kwarg_inputs=trace_input)
else:
model(main_input, input_ids)
traced_model = torch.jit.trace(model, (main_input, input_ids))
except RuntimeError:
self.fail("Couldn't trace module.")

with tempfile.TemporaryDirectory() as tmp_dir_name:
pt_file_name = os.path.join(tmp_dir_name, "traced_model.pt")

try:
torch.jit.save(traced_model, pt_file_name)
except Exception:
self.fail("Couldn't save module.")

try:
loaded_model = torch.jit.load(pt_file_name)
except Exception:
self.fail("Couldn't load module.")

model.to(torch_device)
model.eval()

loaded_model.to(torch_device)
loaded_model.eval()

model_state_dict = model.state_dict()
loaded_model_state_dict = loaded_model.state_dict()

non_persistent_buffers = {}
for key in loaded_model_state_dict.keys():
if key not in model_state_dict.keys():
non_persistent_buffers[key] = loaded_model_state_dict[key]

loaded_model_state_dict = {
key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers
}

self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))

model_buffers = list(model.buffers())
for non_persistent_buffer in non_persistent_buffers.values():
found_buffer = False
for i, model_buffer in enumerate(model_buffers):
if torch.equal(non_persistent_buffer, model_buffer):
found_buffer = True
break

self.assertTrue(found_buffer)
model_buffers.pop(i)

models_equal = True
for layer_name, p1 in model_state_dict.items():
if layer_name in loaded_model_state_dict:
p2 = loaded_model_state_dict[layer_name]
if p1.data.ne(p2.data).sum() > 0:
models_equal = False

self.assertTrue(models_equal)

# Avoid a memory leak. Without this, each call increases RAM usage by ~20MB.
# (Even with this call, there is still a memory leak of ~0.04MB.)
self.clear_torch_jit_class_registry()

@unittest.skip(reason="Hidden_states is tested in individual model tests")
def test_hidden_states_output(self):
pass
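Note: with test_torchscript now set to False, the generic torchscript checks inherited from the shared ModelTesterMixin short-circuit instead of running the override removed above. A minimal sketch of that skip path, using the same guard the deleted method applied (the class and test names below are made up for illustration):

import unittest


class TorchscriptFlagSketch(unittest.TestCase):
    # Mirrors the first guard of the removed override: when the flag is False,
    # the test is skipped before any tracing work happens.
    test_torchscript = False

    def test_torchscript_path(self):
        if not self.test_torchscript:
            self.skipTest(reason="test_torchscript is set to `False`")
        # Tracing, saving, and reloading the model would only run when the flag is True.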
@@ -1010,7 +899,7 @@ class Blip2ModelTest(ModelTesterMixin, PipelineTesterMixin, GenerationTesterMixi
test_pruning = False
test_resize_embeddings = True
test_attention_outputs = False
test_torchscript = True
test_torchscript = False
_is_composite = True

# TODO: Fix the failed tests
@@ -1044,116 +933,6 @@ def test_for_conditional_generation(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_conditional_generation(*config_and_inputs)

def _create_and_check_torchscript(self, config, inputs_dict):
# overwrite because BLIP requires input ids and pixel values as input
if not self.test_torchscript:
self.skipTest(reason="test_torchscript is set to `False`")

configs_no_init = _config_zero_init(config) # To be sure we have no Nan
configs_no_init.torchscript = True
for model_class in self.all_model_classes:
for attn_implementation in ["eager", "sdpa"]:
if attn_implementation == "sdpa" and (not model_class._supports_sdpa or not is_torch_sdpa_available()):
continue

configs_no_init._attn_implementation = attn_implementation
model = model_class(config=configs_no_init)
model.to(torch_device)
model.eval()
inputs = self._prepare_for_class(inputs_dict, model_class)

main_input_name = model_class.main_input_name

try:
if model.config.is_encoder_decoder:
model.config.use_cache = False # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward
main_input = inputs[main_input_name]
input_ids = inputs["input_ids"]
attention_mask = inputs["attention_mask"]
decoder_input_ids = inputs["decoder_input_ids"]
decoder_attention_mask = inputs["decoder_attention_mask"]
model(main_input, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask)
traced_model = torch.jit.trace(
model, (main_input, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask)
)
else:
main_input = inputs[main_input_name]
input_ids = inputs["input_ids"]

if model.config._attn_implementation == "sdpa":
trace_input = {main_input_name: main_input, "input_ids": input_ids}

if "attention_mask" in inputs:
trace_input["attention_mask"] = inputs["attention_mask"]
else:
self.skipTest(reason="testing SDPA without attention_mask is not supported")

model(main_input, attention_mask=inputs["attention_mask"])
# example_kwarg_inputs was introduced in torch==2.0, but it is fine here since SDPA has a requirement on torch>=2.1.
traced_model = torch.jit.trace(model, example_kwarg_inputs=trace_input)
else:
model(main_input, input_ids)
traced_model = torch.jit.trace(model, (main_input, input_ids))
except RuntimeError:
self.fail("Couldn't trace module.")

with tempfile.TemporaryDirectory() as tmp_dir_name:
pt_file_name = os.path.join(tmp_dir_name, "traced_model.pt")

try:
torch.jit.save(traced_model, pt_file_name)
except Exception:
self.fail("Couldn't save module.")

try:
loaded_model = torch.jit.load(pt_file_name)
except Exception:
self.fail("Couldn't load module.")

model.to(torch_device)
model.eval()

loaded_model.to(torch_device)
loaded_model.eval()

model_state_dict = model.state_dict()
loaded_model_state_dict = loaded_model.state_dict()

non_persistent_buffers = {}
for key in loaded_model_state_dict.keys():
if key not in model_state_dict.keys():
non_persistent_buffers[key] = loaded_model_state_dict[key]

loaded_model_state_dict = {
key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers
}

self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))

model_buffers = list(model.buffers())
for non_persistent_buffer in non_persistent_buffers.values():
found_buffer = False
for i, model_buffer in enumerate(model_buffers):
if torch.equal(non_persistent_buffer, model_buffer):
found_buffer = True
break

self.assertTrue(found_buffer)
model_buffers.pop(i)

models_equal = True
for layer_name, p1 in model_state_dict.items():
if layer_name in loaded_model_state_dict:
p2 = loaded_model_state_dict[layer_name]
if p1.data.ne(p2.data).sum() > 0:
models_equal = False

self.assertTrue(models_equal)

# Avoid a memory leak. Without this, each call increases RAM usage by ~20MB.
# (Even with this call, there is still a memory leak of ~0.04MB.)
self.clear_torch_jit_class_registry()

@unittest.skip(reason="Hidden_states is tested in individual model tests")
def test_hidden_states_output(self):
pass
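For reference, the overrides removed in this diff exercised the torch.jit tracing round trip directly. Below is a minimal, self-contained sketch of that pattern; the tiny module, tensor shapes, and names are made up for illustration, and example_kwarg_inputs assumes torch>=2.0, as the removed comment notes:

import os
import tempfile

import torch


class TinyModel(torch.nn.Module):
    # Stand-in for the BLIP-2 classes the removed test traced with pixel_values and input_ids.
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)


model = TinyModel().eval()
example = {"x": torch.randn(2, 4)}

# Trace with keyword inputs, as the SDPA branch of the removed test did.
traced_model = torch.jit.trace(model, example_kwarg_inputs=example)

# Save and reload the traced module, then compare parameters, mirroring the removed checks.
with tempfile.TemporaryDirectory() as tmp_dir_name:
    pt_file_name = os.path.join(tmp_dir_name, "traced_model.pt")
    torch.jit.save(traced_model, pt_file_name)
    loaded_model = torch.jit.load(pt_file_name)

for (name, p1), (_, p2) in zip(model.state_dict().items(), loaded_model.state_dict().items()):
    assert torch.equal(p1, p2), f"parameter mismatch in {name}"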