
[Bug] StaticCache.get_seq_length() returns shape-(1,) Tensor despite -> int contract #45987

@Abineshabee

Description


System Info

transformers : 5.7.0.dev0 (main, commit 84c2e2f)
Python : 3.13.7
Platform : Windows 11 AMD64
PyTorch : 2.11.0+cu126

Who can help?

@ArthurZucker @zucchini-nlp

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

StaticCache.get_seq_length() is typed -> int in the abstract base class, but returns torch.tensor([N]) — a shape-(1,) Tensor — after the first update. DynamicCache.get_seq_length() correctly returns a plain int. This inconsistency means the two cache types are not safely interchangeable.

Reproduction (no weights download needed):

import torch
from transformers import StaticCache, DynamicCache
from transformers.models.llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaModel

config = LlamaConfig(
    hidden_size=64, intermediate_size=128,
    num_hidden_layers=2, num_attention_heads=2,
    num_key_value_heads=2, max_position_embeddings=512,
)
model = LlamaModel(config).eval()

cache = StaticCache(config=config, max_batch_size=1, max_cache_len=128)
input_ids = torch.randint(0, 100, (1, 8))

with torch.no_grad():
    model(input_ids, past_key_values=cache, use_cache=True)

# Type check
seq = cache.get_seq_length()
print(f"StaticCache.get_seq_length()  = {seq!r}  isinstance(int)={isinstance(seq, int)}")
print(f"shape = {seq.shape}")

# DynamicCache for comparison
dyn = DynamicCache()
with torch.no_grad():
    model(input_ids, past_key_values=dyn, use_cache=True)
print(f"DynamicCache.get_seq_length() = {dyn.get_seq_length()!r}  isinstance(int)={isinstance(dyn.get_seq_length(), int)}")

# Downstream arithmetic — return type changes
input_len = input_ids.shape[1]       # int
past_len  = cache.get_seq_length()   # tensor([8])
result    = input_len - past_len
print(f"input_len - past_len = {result!r}  type={type(result)}")

Observed output:

StaticCache.get_seq_length()  = tensor([8])  isinstance(int)=False
shape = torch.Size([1])
DynamicCache.get_seq_length() = 8   isinstance(int)=True
input_len - past_len = tensor([0])  type=<class 'torch.Tensor'>

Expected behavior

StaticCache.get_seq_length() should return a value consistent with the -> int contract declared in the abstract base class (CacheLayerMixin, cache_utils.py), and consistent with what DynamicCache.get_seq_length() returns.

Root cause:

StaticLayer.__init__ (cache_utils.py) stores cumulative_length as a shape-(1,) tensor:

self.cumulative_length = torch.tensor([0], dtype=int)  # shape (1,)

get_seq_length() then returns it directly:

def get_seq_length(self) -> int:
    return self.cumulative_length if self.is_initialized else 0  # returns Tensor

Observed impact:

  1. Interface inconsistency: DynamicCache and StaticCache return different types from the same method. They cannot be safely swapped, which breaks the cache abstraction.

  2. Arithmetic type propagation: int - tensor([N]) returns a Tensor, which can propagate into downstream slicing operations in generation code.

  3. Existing internal workaround: masking_utils.py lines 878–880 already guard against this with an explicit isinstance check, and the comment there explicitly notes "StaticLayer returns a tensor instead of int". Several other call sites in generation/utils.py and modeling files do not have this guard, which may introduce unexpected behavior.
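For illustration, the guard pattern used internally could be factored into a small normalizing helper. This is a hypothetical sketch (the function name `as_int_length` is invented, not library code) of how call sites without the guard could defend against either return type:

```python
import torch

def as_int_length(value):
    # Hypothetical helper: normalize a cache length that may arrive as an
    # int (DynamicCache) or a single-element Tensor (StaticCache today).
    if isinstance(value, torch.Tensor):
        # int() works for 0-dim and single-element tensors alike,
        # but note it forces a host-device sync on GPU tensors.
        return int(value)
    return value

assert as_int_length(torch.tensor([8])) == 8  # shape-(1,) tensor
assert as_int_length(torch.tensor(8)) == 8    # 0-dim scalar tensor
assert as_int_length(8) == 8                  # plain int passes through
```

This only papers over the inconsistency at each call site, which is why fixing the return type at the source seems preferable.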

Suggested fix:

Changing cumulative_length from a shape-(1,) tensor to a 0-dim scalar tensor would preserve compile-friendly tensor semantics while avoiding the shape-(1,) inconsistency:

- self.cumulative_length = torch.tensor([0], dtype=int)
+ self.cumulative_length = torch.tensor(0, dtype=torch.int64)

Note: this still returns a Tensor from get_seq_length(), not a true int. A scalar tensor behaves consistently with int in arithmetic contexts and avoids the need for .item() (which would force a host-device sync). Whether to fully enforce -> int via .item() or update the annotation to reflect the actual return type is a separate design decision for maintainers.
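To illustrate the difference the suggested fix makes (a standalone sketch, independent of transformers), arithmetic with a shape-(1,) tensor propagates the extra dimension, while a 0-dim scalar tensor stays dimensionless and converts cleanly where an int is expected:

```python
import torch

vec = torch.tensor([8])   # shape (1,) -- current cumulative_length
scalar = torch.tensor(8)  # 0-dim      -- proposed cumulative_length

# Subtracting from a Python int: the shape-(1,) version keeps its dim
print((16 - vec).shape)     # torch.Size([1])
print((16 - scalar).shape)  # torch.Size([]) -- stays a scalar

# A 0-dim tensor is directly usable in int contexts (slicing, range, etc.)
print(int(scalar))          # 8
```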
