Replies: 1 comment

Is that because we use tensor.view(dtype)?
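For context on that reply: tensor.view(dtype) reinterprets a tensor's raw bytes without converting any values, whereas tensor.to(dtype) performs a real numeric conversion, which is the only place a scale could be applied. A minimal sketch, assuming a PyTorch build (>= 2.1) with float8 support:

```python
import torch

# .to() performs a real numeric conversion from bf16 down to fp8.
x = torch.tensor([1.5, -2.0, 448.0], dtype=torch.bfloat16)
x_fp8 = x.to(torch.float8_e4m3fn)

# .view() only reinterprets the same bytes under another same-sized dtype:
# the fp8 bit patterns show up as raw uint8 integers, no math involved.
print(x_fp8.view(torch.uint8))

# Converting back with .to() recovers the (quantized) values.
print(x_fp8.to(torch.float32))
```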
Current code in memory_pool.py:

```python
def set_kv_buffer(
    self,
    layer: RadixAttention,
    loc: torch.Tensor,
    cache_k: torch.Tensor,
    cache_v: torch.Tensor,
):
    layer_id = layer.layer_id
    if cache_k.dtype != self.dtype:
        cache_k = cache_k.to(self.dtype)
    if self.store_dtype != self.dtype:
        self.kv_buffer[layer_id][loc] = cache_k.view(self.store_dtype)
    else:
        self.kv_buffer[layer_id][loc] = cache_k
```
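As an aside on the view(self.store_dtype) line (and the reply above): when the cache dtype is fp8, the pool buffer is kept under a plain byte dtype and the values are reinterpreted on the way in and out. A minimal sketch of that store/load round trip, assuming store_dtype is torch.uint8 for fp8 (buffer shape and locations here are illustrative):

```python
import torch

dtype = torch.float8_e4m3fn
store_dtype = torch.uint8  # assumption: fp8 entries stored as raw bytes

kv_buffer = torch.zeros(16, 8, dtype=store_dtype)  # illustrative pool buffer
loc = torch.tensor([3, 7])                          # slots to write

cache_k = torch.randn(2, 8, dtype=torch.bfloat16).to(dtype)

# Store: reinterpret the fp8 values as bytes and write them into the pool.
kv_buffer[loc] = cache_k.view(store_dtype)

# Load: reinterpret the bytes back as fp8 before use. No scaling happens in
# either direction; view() is purely a bit-level reinterpretation.
k_out = kv_buffer[loc].view(dtype)
```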
There is no scale operation in this path when cache_k.dtype != self.dtype, i.e. when running with --kv-cache-dtype fp8_e5m2 or --kv-cache-dtype fp8_e4m3.
But I did see a scale operation in the MHATokenToKVPool class:
```python
def set_kv_buffer(
    self,
    layer: RadixAttention,
    loc: torch.Tensor,
    cache_k: torch.Tensor,
    cache_v: torch.Tensor,
    k_scale: Optional[float] = None,
    v_scale: Optional[float] = None,
):
    layer_id = layer.layer_id
    if cache_k.dtype != self.dtype:
        if k_scale is not None:
            cache_k.div_(k_scale)
        if v_scale is not None:
            cache_v.div_(v_scale)
        cache_k = cache_k.to(self.dtype)
        cache_v = cache_v.to(self.dtype)
    if self.store_dtype != self.dtype:
        self.k_buffer[layer_id][loc] = cache_k.view(self.store_dtype)
        self.v_buffer[layer_id][loc] = cache_v.view(self.store_dtype)
    else:
        self.k_buffer[layer_id][loc] = cache_k
        self.v_buffer[layer_id][loc] = cache_v
```
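For reference, the div_-then-cast pattern above is per-tensor scaled quantization. Below is a hedged sketch contrasting an unscaled cast with a scaled one; the abs-max scale choice is purely illustrative, not necessarily how k_scale/v_scale are produced upstream (again assuming PyTorch >= 2.1 with float8 support):

```python
import torch

FP8_E4M3_MAX = 448.0  # largest finite float8_e4m3fn value

cache_k = torch.randn(4, 8, dtype=torch.bfloat16) * 1000  # far beyond fp8 range

# Unscaled path (like the first snippet): values outside the fp8 range do
# not survive the cast (e4m3fn has no inf, so they come back as NaN).
k_bad = cache_k.to(torch.float8_e4m3fn).to(torch.float32)

# Scaled path (like the MHA snippet): divide by a per-tensor scale first,
# cast, then multiply the scale back in at dequantization time.
k_scale = (cache_k.abs().max().float() / FP8_E4M3_MAX).item()
k_fp8 = (cache_k.float() / k_scale).to(torch.float8_e4m3fn)
k_good = k_fp8.to(torch.float32) * k_scale

print((cache_k.float() - k_bad).abs().max())   # NaN/huge: values destroyed
print((cache_k.float() - k_good).abs().max())  # bounded quantization error
```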
I'm wondering: no matter whether it's MHA or MLA, if k or v needs to be quantized from bf16 to fp8, doesn't it always need a scale? Am I missing something?

Thanks.