Accuracy Issue for Sharded Llama #19948

Open
stbaione opened this issue Feb 10, 2025 · 3 comments
Labels
bug 🐞 Something isn't working

Comments

@stbaione

stbaione commented Feb 10, 2025

What happened?

The shortfin server is showing corrupt outputs at HEAD of IREE when running llama3.1_8b_tp8 (nod-ai/shark-ai#934):

curl http://localhost:8081/generate \
    -H "Content-Type: application/json" \
    -d '{
        "text": "Name the capital of the United States.",
        "sampling_params": {"max_completion_tokens": 50}
    }'
data:  Washington!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! !.  I! !!!!!!!
.NameTheNameThecapital ofthe capital ofthe!!!!.!.!

I was able to bisect this to 4fffb0e.
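
(A sketch of the typical git bisect workflow for a regression like this; the known-good commit below is a placeholder, not something from this report:)

git bisect start
git bisect bad HEAD
git bisect good <last-known-good-commit>  # placeholder for the last clean commit
# At each step: rebuild IREE, restart the shortfin server, re-run the curl
# request above, then mark the checkout:
git bisect good   # output looked clean
git bisect bad    # output was corrupt
git bisect reset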

I'm not sure of a good way to reproduce this or provide good signal on the IREE side.

I'll include steps to invoke iree-run-module below while I investigate this, in case that's useful on the IREE side. Otherwise, let me know if there's a better method of reproduction I could provide.

Steps to reproduce your issue

  1. You can use the irpa files located at /data/llama3.1/weights/8b/fp16/tp8 on mi300x-3 or /shark_dev/data/llama3.1/weights/8b/fp16/tp8 on mi300x
  2. Download the MLIR
wget https://sharkpublic.blob.core.windows.net/sharkpublic/stephen/llama3.1_8b_tp8/mlir/iree_issue_corrupt_sharded/llama3.1_8b_tp8.mlir
  3. Download the inputs (a quick sanity check on these files is sketched after step 5)
mkdir 8b_short_inputs
cd 8b_short_inputs

wget https://sharkpublic.blob.core.windows.net/sharkpublic/stephen/llama3.1_8b_tp8/inputs/prefill/iree_issue_corrupt_shards/tokens.npy
wget https://sharkpublic.blob.core.windows.net/sharkpublic/stephen/llama3.1_8b_tp8/inputs/prefill/iree_issue_corrupt_shards/seq_ids.npy
wget https://sharkpublic.blob.core.windows.net/sharkpublic/stephen/llama3.1_8b_tp8/inputs/prefill/iree_issue_corrupt_shards/seq_block_ids.npy
wget https://sharkpublic.blob.core.windows.net/sharkpublic/stephen/llama3.1_8b_tp8/inputs/prefill/iree_issue_corrupt_shards/cache_state_shard_0.npy
wget https://sharkpublic.blob.core.windows.net/sharkpublic/stephen/llama3.1_8b_tp8/inputs/prefill/iree_issue_corrupt_shards/cache_state_shard_1.npy
wget https://sharkpublic.blob.core.windows.net/sharkpublic/stephen/llama3.1_8b_tp8/inputs/prefill/iree_issue_corrupt_shards/cache_state_shard_2.npy
wget https://sharkpublic.blob.core.windows.net/sharkpublic/stephen/llama3.1_8b_tp8/inputs/prefill/iree_issue_corrupt_shards/cache_state_shard_3.npy
wget https://sharkpublic.blob.core.windows.net/sharkpublic/stephen/llama3.1_8b_tp8/inputs/prefill/iree_issue_corrupt_shards/cache_state_shard_4.npy
wget https://sharkpublic.blob.core.windows.net/sharkpublic/stephen/llama3.1_8b_tp8/inputs/prefill/iree_issue_corrupt_shards/cache_state_shard_5.npy
wget https://sharkpublic.blob.core.windows.net/sharkpublic/stephen/llama3.1_8b_tp8/inputs/prefill/iree_issue_corrupt_shards/cache_state_shard_6.npy
wget https://sharkpublic.blob.core.windows.net/sharkpublic/stephen/llama3.1_8b_tp8/inputs/prefill/iree_issue_corrupt_shards/cache_state_shard_7.npy
  4. Compile
iree-compile llama3.1_8b_tp8.mlir \
  -o llama3.1_8b_tp8.vmfb \
  --iree-hal-target-device=hip[0] \
  --iree-hal-target-device=hip[1] \
  --iree-hal-target-device=hip[2] \
  --iree-hal-target-device=hip[3] \
  --iree-hal-target-device=hip[4] \
  --iree-hal-target-device=hip[5] \
  --iree-hal-target-device=hip[6] \
  --iree-hal-target-device=hip[7] \
  --iree-hip-target=gfx942 \
  --iree-dispatch-creation-enable-aggressive-fusion=true \
  --iree-global-opt-propagate-transposes=true \
  --iree-opt-aggressively-propagate-transposes=true \
  --iree-opt-data-tiling=false \
  --iree-preprocessing-pass-pipeline='builtin.module(util.func(iree-preprocessing-generalize-linalg-matmul-experimental))' \
  --iree-hal-indirect-command-buffers=true \
  --iree-stream-resource-memory-model=discrete \
  --iree-hal-memoization=true \
  --iree-opt-strip-assertions
  5. Invoke iree-run-module
ROCR_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
iree-run-module \
  --hip_use_streams=true \
  --module=llama3.1_8b_tp8.vmfb \
  --parameters=model=/data/llama3.1/weights/8b/fp16/tp8/llama3.1_8b_fp16_tp8_parameters.irpa \
  --parameters=model=/data/llama3.1/weights/8b/fp16/tp8/llama3.1_8b_fp16_tp8_parameters.rank0.irpa \
  --parameters=model=/data/llama3.1/weights/8b/fp16/tp8/llama3.1_8b_fp16_tp8_parameters.rank1.irpa \
  --parameters=model=/data/llama3.1/weights/8b/fp16/tp8/llama3.1_8b_fp16_tp8_parameters.rank2.irpa \
  --parameters=model=/data/llama3.1/weights/8b/fp16/tp8/llama3.1_8b_fp16_tp8_parameters.rank3.irpa \
  --parameters=model=/data/llama3.1/weights/8b/fp16/tp8/llama3.1_8b_fp16_tp8_parameters.rank4.irpa \
  --parameters=model=/data/llama3.1/weights/8b/fp16/tp8/llama3.1_8b_fp16_tp8_parameters.rank5.irpa \
  --parameters=model=/data/llama3.1/weights/8b/fp16/tp8/llama3.1_8b_fp16_tp8_parameters.rank6.irpa \
  --parameters=model=/data/llama3.1/weights/8b/fp16/tp8/llama3.1_8b_fp16_tp8_parameters.rank7.irpa \
  --device=hip://0 \
  --device=hip://1 \
  --device=hip://2 \
  --device=hip://3 \
  --device=hip://4 \
  --device=hip://5 \
  --device=hip://6 \
  --device=hip://7 \
  --function=prefill_bs4 \
  --input=@8b_short_inputs/tokens.npy \
  --input=@8b_short_inputs/seq_ids.npy \
  --input=@8b_short_inputs/seq_block_ids.npy \
  --input=@8b_short_inputs/cache_state_shard_0.npy \
  --input=@8b_short_inputs/cache_state_shard_1.npy \
  --input=@8b_short_inputs/cache_state_shard_2.npy \
  --input=@8b_short_inputs/cache_state_shard_3.npy \
  --input=@8b_short_inputs/cache_state_shard_4.npy \
  --input=@8b_short_inputs/cache_state_shard_5.npy \
  --input=@8b_short_inputs/cache_state_shard_6.npy \
  --input=@8b_short_inputs/cache_state_shard_7.npy
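
Given the original "NaNs in KVCache" title, a quick sanity check of the downloaded inputs may be useful; a minimal sketch, assuming Python with numpy is installed (not part of the original repro steps):

# Print shape/dtype for each input and flag NaNs in the fp16 cache shards.
for f in 8b_short_inputs/*.npy; do
  python -c 'import sys, numpy as np
a = np.load(sys.argv[1])
nan = np.isnan(a).any() if np.issubdtype(a.dtype, np.floating) else False
print(sys.argv[1], a.shape, a.dtype, "has NaNs" if nan else "ok")' "$f"
done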

What component(s) does this issue relate to?

No response

Version information

0781072

Additional context

No response

stbaione added the bug 🐞 Something isn't working label Feb 10, 2025
stbaione added a commit to nod-ai/shark-ai that referenced this issue Feb 10, 2025:
Add a warning for sharded llama accuracy until iree-org/iree#19948 is resolved.
stbaione changed the title from "NaNs in KVCache for Sharded Llama" to "Accuracy Issue for Sharded Llama" Feb 11, 2025
stbaione added a commit to stbaione/iree that referenced this issue Feb 11, 2025:
This reverts the changes to `SchedulingExecution.cpp` from commit 4fffb0e. This change caused corrupt tokens to be output from sharded llama models (iree-org#19948).
@stbaione
Author

Collected two new traces from the shortfin server: one with the specified commit, which has corrupt tokens, and one with the commit reverted, which has good token output. Not sure if there are any insights that can be gleaned from these:

Good Output

trace

Prompt - 0:
<|begin_of_text|>Name the capital of the United States.<|eot_id|>
Response:
data: assistant
The capital of the United States is Washington, D.C.


--------------------------------------------------

Bad Output

trace

Prompt - 0:
<|begin_of_text|>Name the capital of the United States.<|eot_id|>
Response:
data: assistant
://://://://://://://://://://://://://_REF
I
I
I


--------------------------------------------------

@stbaione
Author

Expanding on:

Prompt - 0:
<|begin_of_text|>Name the capital of the United States.<|eot_id|>
Response:
data: assistant
://://://://://://://://://://://://://_REF
I
I
I


--------------------------------------------------

The bad tokens start to appear during the 3rd decode invocation. The outputs, in order, are:

[128006, 78191, 198, 1129, 1129, 1129, 1129, ..., 12592, 198, 40, 198, 40, 128009]

Prefill: 128006 - Empty String (good)
Decode 1: 78191 - assistant (good)
Decode 2: 198 - \n (good)
Decode 3: 1129 - :// (bad)

From there it repeats for a while, until it hits the tokens at the end of the list above, which are also nonsensical.

The outputs from the good run were:

[128006, 78191, 198, 791, 6864, ...]

The good tokens had no repetitions or discernible patterns in the output. As you can see, at the 3rd decode step the outputs start to diverge (1129 vs 791).
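
For cross-checking, the divergent ids can be decoded directly; a minimal sketch, assuming the Llama 3.1 tokenizer is available through transformers (the model name here is an assumption, not from this thread):

# Decode the ids around the divergence at decode step 3.
python -c 'from transformers import AutoTokenizer
tok = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")  # assumed tokenizer
for t in (128006, 78191, 198, 1129, 791, 6864):
    print(t, repr(tok.decode([t])))'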

monorimet pushed a commit to nod-ai/shark-ai that referenced this issue Feb 13, 2025:
Add a warning for sharded llama accuracy until iree-org/iree#19948 is resolved.
@stbaione
Author

Compiled to the stream phase for both the good and bad outputs. Noting these here in case they contain anything valuable for the issue:

Good
Bad
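
These dumps were presumably produced by stopping compilation at the stream phase; a sketch using iree-compile's standard --compile-to option with the same flags as step 4 (the exact command isn't stated in this thread):

iree-compile llama3.1_8b_tp8.mlir \
  -o llama3.1_8b_tp8.stream.mlir \
  --compile-to=stream \
  --iree-hal-target-device=hip[0] \
  --iree-hal-target-device=hip[1] \
  --iree-hal-target-device=hip[2] \
  --iree-hal-target-device=hip[3] \
  --iree-hal-target-device=hip[4] \
  --iree-hal-target-device=hip[5] \
  --iree-hal-target-device=hip[6] \
  --iree-hal-target-device=hip[7] \
  --iree-hip-target=gfx942 \
  --iree-dispatch-creation-enable-aggressive-fusion=true \
  --iree-global-opt-propagate-transposes=true \
  --iree-opt-aggressively-propagate-transposes=true \
  --iree-opt-data-tiling=false \
  --iree-preprocessing-pass-pipeline='builtin.module(util.func(iree-preprocessing-generalize-linalg-matmul-experimental))' \
  --iree-hal-indirect-command-buffers=true \
  --iree-stream-resource-memory-model=discrete \
  --iree-hal-memoization=true \
  --iree-opt-strip-assertions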
