Performance regression (10-15%) in LLAMA 8B and BERT training with torch-xla v2.8.0 compared to v2.8.0-rc3 #9605

@rajkthakur

Description

🐛 Bug

We are seeing roughly a 10-15% performance reduction for LLAMA 8B and BERT training when moving from torch-xla v2.8.0-rc3 to torch-xla v2.8.0. The regression was bisected to the commit range v2.8.0-rc3...v2.8.0, specifically #9547. Rebuilding the wheels with change ad76b20 reverted restores the original performance.
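For reference, a sketch of the revert-and-rebuild workflow used to confirm the culprit commit. The checkout ref, build command, and wheel path are assumptions, not the exact commands we ran:

```shell
# Sketch only: exact build invocation may differ per environment.
git clone --recursive https://github.com/pytorch/xla.git
cd xla
git checkout v2.8.0             # assumed tag for the released wheel
git revert --no-edit ad76b20    # revert the suspect change from #9547
python setup.py bdist_wheel     # rebuild the torch-xla wheel
pip install dist/torch_xla-*.whl --force-reinstall
```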

Note: This is a separate issue that we observed after resolving the logging issue #9569.

To Reproduce

Steps to reproduce the behavior:

  1. Install the latest Neuron torch-neuronx + torch-xla + torch + torchvision, then replace torch-xla and torch with the 2.8.0 releases
  2. Run https://awsdocs-neuron.readthedocs-hosted.com/en/latest/frameworks/torch/torch-neuronx/tutorials/training/bert.html#hf-bert-pretraining-tutorial
  3. Compare performance against a run with v2.8.0-rc3
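The comparison in step 3 can be expressed as a relative slowdown against the rc3 baseline. A minimal sketch, with hypothetical throughput numbers purely for illustration:

```python
def regression_pct(baseline_throughput: float, new_throughput: float) -> float:
    """Percent slowdown of the new run relative to the baseline.

    Positive values indicate a regression (new run is slower).
    """
    return (baseline_throughput - new_throughput) / baseline_throughput * 100.0


# Hypothetical numbers: an rc3 run at 100 seq/s vs a v2.8.0 run at 88 seq/s
print(round(regression_pct(100.0, 88.0), 1))  # → 12.0, within the reported 10-15% band
```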

Expected behavior

Performance on par with torch-xla v2.8.0-rc3 on all models.

Environment

  • Reproducible on XLA backend [CPU/TPU/CUDA]: Neuron
  • torch_xla version: 2.8.0
