From 9ca7ba8c2ed87215c29fbf811f9eb98f16d58030 Mon Sep 17 00:00:00 2001
From: ariG23498
Date: Tue, 6 May 2025 13:18:58 +0200
Subject: [PATCH] fix: attention output

---
 .../decision_transformer/modeling_decision_transformer.py | 3 ++-
 src/transformers/models/gpt2/modeling_gpt2.py             | 3 ++-
 2 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/src/transformers/models/decision_transformer/modeling_decision_transformer.py b/src/transformers/models/decision_transformer/modeling_decision_transformer.py
index ab2a3024052b..953f23200dce 100755
--- a/src/transformers/models/decision_transformer/modeling_decision_transformer.py
+++ b/src/transformers/models/decision_transformer/modeling_decision_transformer.py
@@ -273,6 +273,7 @@ def forward(
         output_attentions: Optional[bool] = False,
         **kwargs,
     ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
+        input_shape = hidden_states.shape[:-1]
         is_cross_attention = encoder_hidden_states is not None
         if is_cross_attention:
             if not hasattr(self, "q_attn"):
@@ -339,7 +340,7 @@ def forward(
             **kwargs,
         )
 
-        attn_output = attn_output.reshape(*attn_output.shape[:-2], -1).contiguous()
+        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
         attn_output = self.c_proj(attn_output)
         attn_output = self.resid_dropout(attn_output)
 
diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py
index 0af4d9906552..1b5f60298c37 100644
--- a/src/transformers/models/gpt2/modeling_gpt2.py
+++ b/src/transformers/models/gpt2/modeling_gpt2.py
@@ -283,6 +283,7 @@ def forward(
         output_attentions: Optional[bool] = False,
         **kwargs,
     ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
+        input_shape = hidden_states.shape[:-1]
         is_cross_attention = encoder_hidden_states is not None
         if is_cross_attention:
             if not hasattr(self, "q_attn"):
@@ -349,7 +350,7 @@ def forward(
             **kwargs,
         )
 
-        attn_output = attn_output.reshape(*attn_output.shape[:-2], -1).contiguous()
+        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
         attn_output = self.c_proj(attn_output)
         attn_output = self.resid_dropout(attn_output)
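
Note on the change: the old code flattened the attention output against its own
trailing dims, which only gives (batch, seq_len, hidden) when the attention
backend returns exactly a 4-D (batch, seq_len, num_heads, head_dim) tensor.
Capturing input_shape = hidden_states.shape[:-1] before the attention call pins
the flatten to the layer input's (batch, seq_len) regardless of how the backend
lays out its trailing dims. A minimal, self-contained sketch of the shape
behaviour; the 3-D "already merged" backend output below is a hypothetical case
for illustration, not something taken from the patch:

    import torch

    batch, seq_len, num_heads, head_dim = 2, 5, 4, 8
    hidden = num_heads * head_dim

    hidden_states = torch.randn(batch, seq_len, hidden)
    input_shape = hidden_states.shape[:-1]  # (batch, seq_len), saved before attention runs

    # Usual backend output: (batch, seq_len, num_heads, head_dim).
    # Old and new reshapes agree here.
    attn_output = torch.randn(batch, seq_len, num_heads, head_dim)
    assert attn_output.reshape(*attn_output.shape[:-2], -1).shape == (batch, seq_len, hidden)
    assert attn_output.reshape(*input_shape, -1).shape == (batch, seq_len, hidden)

    # Hypothetical backend that already merged the head dims into a 3-D tensor:
    merged = torch.randn(batch, seq_len, hidden)
    # The old reshape drops to 2-D, collapsing seq_len into the feature dim...
    assert merged.reshape(*merged.shape[:-2], -1).shape == (batch, seq_len * hidden)
    # ...while reshaping against the saved input shape stays (batch, seq_len, hidden).
    assert merged.reshape(*input_shape, -1).shape == (batch, seq_len, hidden)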