diff --git a/src/transformers/models/phi3/configuration_phi3.py b/src/transformers/models/phi3/configuration_phi3.py
index 6fe6e1cdfca8..c9bc2c1a794d 100644
--- a/src/transformers/models/phi3/configuration_phi3.py
+++ b/src/transformers/models/phi3/configuration_phi3.py
@@ -81,6 +81,8 @@ class Phi3Config(PretrainedConfig):
             contain the following keys: `type`, `short_factor` and `long_factor`. The `type` must be `longrope` and
             the `short_factor` and `long_factor` must be lists of numbers with the same length as the hidden size
             divided by the number of attention heads divided by 2.
+        partial_rotary_factor (`float`, *optional*, defaults to 1.0):
+            Percentage of the query and keys which will have rotary embedding. Must be between 0.0 and 1.0.
         bos_token_id (`int`, *optional*, defaults to 1):
             The id of the "beginning-of-sequence" token.
         eos_token_id (`int`, *optional*, defaults to 32000):
@@ -134,6 +136,7 @@ def __init__(
         tie_word_embeddings=False,
         rope_theta=10000.0,
         rope_scaling=None,
+        partial_rotary_factor=1.0,
         bos_token_id=1,
         eos_token_id=32000,
         pad_token_id=32000,
@@ -161,6 +164,7 @@ def __init__(
         self.use_cache = use_cache
         self.rope_theta = rope_theta
         self.rope_scaling = rope_scaling
+        self.partial_rotary_factor = partial_rotary_factor
         self._rope_scaling_adjustment()
         self._rope_scaling_validation()
         self.sliding_window = sliding_window
@@ -210,9 +214,10 @@ def _rope_scaling_validation(self):
             raise ValueError(
                 f"`rope_scaling`'s short_factor field must be a list of numbers, got {rope_scaling_short_factor}"
             )
-        if not len(rope_scaling_short_factor) == self.hidden_size // self.num_attention_heads // 2:
+        rotary_ndims = int(self.hidden_size // self.num_attention_heads * self.partial_rotary_factor)
+        if not len(rope_scaling_short_factor) == rotary_ndims // 2:
             raise ValueError(
-                f"`rope_scaling`'s short_factor field must have length {self.hidden_size // self.num_attention_heads // 2}, got {len(rope_scaling_short_factor)}"
+                f"`rope_scaling`'s short_factor field must have length {rotary_ndims // 2}, got {len(rope_scaling_short_factor)}"
             )
         if not (
             isinstance(rope_scaling_long_factor, list)
@@ -221,9 +226,9 @@ def _rope_scaling_validation(self):
             raise ValueError(
                 f"`rope_scaling`'s long_factor field must be a list of numbers, got {rope_scaling_long_factor}"
             )
-        if not len(rope_scaling_long_factor) == self.hidden_size // self.num_attention_heads // 2:
+        if not len(rope_scaling_long_factor) == rotary_ndims // 2:
             raise ValueError(
-                f"`rope_scaling`'s long_factor field must have length {self.hidden_size // self.num_attention_heads // 2}, got {len(rope_scaling_long_factor)}"
+                f"`rope_scaling`'s long_factor field must have length {rotary_ndims // 2}, got {len(rope_scaling_long_factor)}"
             )
 
 
diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py
index e86e028b4027..3ee5d8c5d0d8 100644
--- a/src/transformers/models/phi3/modeling_phi3.py
+++ b/src/transformers/models/phi3/modeling_phi3.py
@@ -82,33 +82,6 @@ def rotate_half(x):
     return torch.cat((-x2, x1), dim=-1)
 
 
-def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
-    """Applies Rotary Position Embedding to the query and key tensors.
-
-    Args:
-        q (`torch.Tensor`): The query tensor.
-        k (`torch.Tensor`): The key tensor.
-        cos (`torch.Tensor`): The cosine part of the rotary embedding.
-        sin (`torch.Tensor`): The sine part of the rotary embedding.
-        position_ids (`torch.Tensor`, *optional*):
-            Deprecated and unused.
-        unsqueeze_dim (`int`, *optional*, defaults to 1):
-            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
-            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
-            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
-            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
-            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
-            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
-    Returns:
-        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
-    """
-    cos = cos.unsqueeze(unsqueeze_dim)
-    sin = sin.unsqueeze(unsqueeze_dim)
-    q_embed = (q * cos) + (rotate_half(q) * sin)
-    k_embed = (k * cos) + (rotate_half(k) * sin)
-    return q_embed, k_embed
-
-
 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     """
     This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
@@ -147,6 +120,38 @@ def eager_attention_forward(
     return attn_output, attn_weights
 
 
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+    """Applies Rotary Position Embedding to the query and key tensors.
+
+    Args:
+        q (`torch.Tensor`): The query tensor.
+        k (`torch.Tensor`): The key tensor.
+        cos (`torch.Tensor`): The cosine part of the rotary embedding.
+        sin (`torch.Tensor`): The sine part of the rotary embedding.
+        position_ids (`torch.Tensor`, *optional*):
+            Deprecated and unused.
+        unsqueeze_dim (`int`, *optional*, defaults to 1):
+            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+    Returns:
+        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+    """
+    cos = cos.unsqueeze(unsqueeze_dim)
+    sin = sin.unsqueeze(unsqueeze_dim)
+
+    rotary_dim = cos.shape[-1]
+    q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
+    k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]
+
+    q_embed = torch.cat([(q_rot * cos) + (rotate_half(q_rot) * sin), q_pass], dim=-1)
+    k_embed = torch.cat([(k_rot * cos) + (rotate_half(k_rot) * sin), k_pass], dim=-1)
+    return q_embed, k_embed
+
+
 class Phi3Attention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
diff --git a/src/transformers/models/phi3/modular_phi3.py b/src/transformers/models/phi3/modular_phi3.py
index 27f7c42f5bb8..caddb77f6fcd 100644
--- a/src/transformers/models/phi3/modular_phi3.py
+++ b/src/transformers/models/phi3/modular_phi3.py
@@ -34,8 +34,8 @@
     MistralForTokenClassification,
     MistralPreTrainedModel,
     MistralRotaryEmbedding,
-    apply_rotary_pos_emb,
     eager_attention_forward,
+    rotate_half,
 )
 from .configuration_phi3 import Phi3Config
 
@@ -64,6 +64,38 @@ def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
         return self.down_proj(up_states)
 
 
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+    """Applies Rotary Position Embedding to the query and key tensors.
+
+    Args:
+        q (`torch.Tensor`): The query tensor.
+        k (`torch.Tensor`): The key tensor.
+        cos (`torch.Tensor`): The cosine part of the rotary embedding.
+        sin (`torch.Tensor`): The sine part of the rotary embedding.
+        position_ids (`torch.Tensor`, *optional*):
+            Deprecated and unused.
+        unsqueeze_dim (`int`, *optional*, defaults to 1):
+            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+    Returns:
+        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+    """
+    cos = cos.unsqueeze(unsqueeze_dim)
+    sin = sin.unsqueeze(unsqueeze_dim)
+
+    rotary_dim = cos.shape[-1]
+    q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
+    k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]
+
+    q_embed = torch.cat([(q_rot * cos) + (rotate_half(q_rot) * sin), q_pass], dim=-1)
+    k_embed = torch.cat([(k_rot * cos) + (rotate_half(k_rot) * sin), k_pass], dim=-1)
+    return q_embed, k_embed
+
+
 class Phi3Attention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""