Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DRAFT/DO NOT REVIEW] Add Rotary Embedding from ONNX Opset 23 #23507

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from

Conversation

shubhambhokare1
Copy link
Contributor

Description

Add kernels for Rotary Embedding op that is planned to be added in ONNX opset 23.

Motivation and Context

This PR is currently experimental, and the kernels are added to a temporary location in contrib_ops. Once the following happens:

  1. Add Rotary Embedding op to ONNX opset 23 onnx/onnx#6461 is merged
  2. This op is released in the latest ONNX release
    These kernels will be migrated from contrib_ops/onnx_std_exp to the correct location.

@@ -0,0 +1,139 @@
// Copyright (c) Microsoft Corporation. All rights reserved.

Check warning

Code scanning / lintrunner

CLANGFORMAT/format Warning

See https://clang.llvm.org/docs/ClangFormat.html.
Run lintrunner -a to apply this patch.
@@ -0,0 +1,30 @@
// Copyright (c) Microsoft Corporation. All rights reserved.

Check warning

Code scanning / lintrunner

CLANGFORMAT/format Warning

See https://clang.llvm.org/docs/ClangFormat.html.
Run lintrunner -a to apply this patch.
@@ -0,0 +1,633 @@
// Copyright (c) Microsoft Corporation. All rights reserved.

Check warning

Code scanning / lintrunner

CLANGFORMAT/format Warning test

See https://clang.llvm.org/docs/ClangFormat.html.
Run lintrunner -a to apply this patch.
if (rotary_emb_dim < head_size) {
std::memcpy(output_data + rotary_emb_dim,
input_data + rotary_emb_dim,
(head_size - rotary_emb_dim) * sizeof(T));

Check warning

Code scanning / PREfast

Arithmetic overflow: Using operator '-' on a 4 byte value and then casting the result to a 8 byte value. Cast the value to the wider type before calling operator '-' to avoid overflow (io.2).

Arithmetic overflow: Using operator '-' on a 4 byte value and then casting the result to a 8 byte value. Cast the value to the wider type before calling operator '-' to avoid overflow (io.2).
if (rotary_emb_dim < head_size) {
std::memcpy(output_data + rotary_emb_dim,
input_data + rotary_emb_dim,
(head_size - rotary_emb_dim) * sizeof(T));

Check warning

Code scanning / PREfast

Arithmetic overflow: Using operator '-' on a 4 byte value and then casting the result to a 8 byte value. Cast the value to the wider type before calling operator '-' to avoid overflow (io.2).

Arithmetic overflow: Using operator '-' on a 4 byte value and then casting the result to a 8 byte value. Cast the value to the wider type before calling operator '-' to avoid overflow (io.2).
}

// Interleaved = true, pos ids shape = (1)
TEST(RotaryEmbeddingONNXTest, RotaryEmbeddingONNX_Interleaved_LargeData_LlamaMSFT) {

Check warning

Code scanning / CodeQL

Poorly documented large function Warning test

Poorly documented function: fewer than 2% comments for a function of 196 lines.
}

// Interleaved = false, pos ids shape = (1)
TEST(RotaryEmbeddingONNXTest, RotaryEmbeddingONNX_NotInterleaved_LargeData_LlamaMSFT) {

Check warning

Code scanning / CodeQL

Poorly documented large function Warning test

Poorly documented function: fewer than 2% comments for a function of 196 lines.
@@ -0,0 +1,199 @@
// Copyright (c) Microsoft Corporation. All rights reserved.

Check warning

Code scanning / lintrunner

CLANGFORMAT/format Warning

See https://clang.llvm.org/docs/ClangFormat.html.
Run lintrunner -a to apply this patch.
@shubhambhokare1 shubhambhokare1 self-assigned this Jan 29, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant