diff --git a/custom_ops/gpu_ops/swap_cache_batch.cu b/custom_ops/gpu_ops/swap_cache_batch.cu
index 111a0cda99d..82cedd926d6 100644
--- a/custom_ops/gpu_ops/swap_cache_batch.cu
+++ b/custom_ops/gpu_ops/swap_cache_batch.cu
@@ -16,127 +16,159 @@
 #include "paddle/extension.h"
 
 template <paddle::DataType D>
-void SwapCacheImplAllLayers(const std::vector<paddle::Tensor>& cache_gpu_tensors, // gpu
-               const std::vector<int64_t>& cache_cpu_ptrs, // cpu
-               const int64_t& max_block_num_cpu,
-               const std::vector<int64_t>& swap_block_ids_gpu,
-               const std::vector<int64_t>& swap_block_ids_cpu,
-               int mode) {
-    typedef PDTraits<D> traits_;
-    typedef typename traits_::DataType DataType_;
-    typedef typename traits_::data_t data_t;
-    auto stream = cache_gpu_tensors[0].stream();
-    for(int layer_idx=0; layer_idx < cache_gpu_tensors.size(); layer_idx++){
-        const paddle::Tensor& cache_gpu = cache_gpu_tensors[layer_idx];
-        const int64_t& cache_cpu_pointer = cache_cpu_ptrs[layer_idx];
-        data_t* cache_gpu_ptr = const_cast<data_t*>(cache_gpu.data<data_t>());
-        auto* cache_cpu_ptr = reinterpret_cast<DataType_*>(cache_cpu_pointer);
-        auto cache_shape = cache_gpu.shape();
-        const int64_t max_block_num_gpu = cache_shape[0];
-        const int64_t num_heads = cache_shape[1];
-        const int64_t block_size = cache_shape[2];
-        const int64_t head_dim = cache_shape[3];
-        const int64_t cache_stride = num_heads * block_size * head_dim;
+void SwapCacheImplAllLayers(
+    const std::vector<paddle::Tensor>& cache_gpu_tensors,  // gpu
+    const std::vector<int64_t>& cache_cpu_ptrs,            // cpu
+    const int64_t& max_block_num_cpu,
+    const std::vector<int64_t>& swap_block_ids_gpu,
+    const std::vector<int64_t>& swap_block_ids_cpu,
+    int mode) {
+  typedef PDTraits<D> traits_;
+  typedef typename traits_::DataType DataType_;
+  typedef typename traits_::data_t data_t;
+  auto stream = cache_gpu_tensors[0].stream();
+  for (int layer_idx = 0; layer_idx < cache_gpu_tensors.size(); layer_idx++) {
+    const paddle::Tensor& cache_gpu = cache_gpu_tensors[layer_idx];
+    const int64_t& cache_cpu_pointer = cache_cpu_ptrs[layer_idx];
+    data_t* cache_gpu_ptr = const_cast<data_t*>(cache_gpu.data<data_t>());
+    auto* cache_cpu_ptr = reinterpret_cast<DataType_*>(cache_cpu_pointer);
+    auto cache_shape = cache_gpu.shape();
+    const int64_t max_block_num_gpu = cache_shape[0];
+    const int64_t num_heads = cache_shape[1];
+    const int64_t block_size = cache_shape[2];
+    int64_t head_dim = 1;
+    if (cache_shape.size() == 4) {
+      head_dim = cache_shape[3];
+    }
+    const int64_t cache_stride = num_heads * block_size * head_dim;
 
-        auto stream = cache_gpu.stream();
-        if (swap_block_ids_gpu.size() == 0) {
-            return;
-        }
-        int i = 0;
-        int64_t consecutive_block_count = 1;
-        int64_t last_gpu_block_id = swap_block_ids_gpu[i];
-        int64_t last_cpu_block_id = swap_block_ids_cpu[i];
-        int64_t first_gpu_block_id = last_gpu_block_id; // first block id in a consecutive block ids
-        int64_t first_cpu_block_id = last_cpu_block_id;
-        i += 1;
-        while(true){
-            if (i >= swap_block_ids_gpu.size()) {
-                break;
-            }
-            int64_t gpu_block_id = swap_block_ids_gpu[i];
-            int64_t cpu_block_id = swap_block_ids_cpu[i];
-            assert(gpu_block_id >= 0 && gpu_block_id < max_block_num_gpu);
-            assert(cpu_block_id >= 0 && cpu_block_id < max_block_num_cpu);
-            if (gpu_block_id == last_gpu_block_id + 1 && cpu_block_id == last_cpu_block_id + 1){ // consecutive
-                consecutive_block_count += 1;
-                last_gpu_block_id = gpu_block_id;
-                last_cpu_block_id = cpu_block_id;
-            } else{
-                // end of a consecutive block ids
-                auto *cache_gpu_ptr_now = cache_gpu_ptr + first_gpu_block_id * cache_stride;
-                auto *cache_cpu_ptr_now = cache_cpu_ptr + first_cpu_block_id * cache_stride;
-                if (mode == 0) { // copy from device to host
-                    cudaMemcpyAsync(cache_cpu_ptr_now, cache_gpu_ptr_now, cache_stride * sizeof(DataType_) * consecutive_block_count, cudaMemcpyDeviceToHost, stream);
-                } else { // copy from host to device
-                    cudaMemcpyAsync(cache_gpu_ptr_now, cache_cpu_ptr_now, cache_stride * sizeof(DataType_) * consecutive_block_count, cudaMemcpyHostToDevice, stream);
-                }
-                first_gpu_block_id = gpu_block_id;
-                first_cpu_block_id = cpu_block_id;
-                last_gpu_block_id = gpu_block_id;
-                last_cpu_block_id = cpu_block_id;
-                consecutive_block_count = 1;
-            }
-            i += 1;
-        }
-        // last batch
-        auto *cache_gpu_ptr_now = cache_gpu_ptr + first_gpu_block_id * cache_stride;
-        auto *cache_cpu_ptr_now = cache_cpu_ptr + first_cpu_block_id * cache_stride;
-        if (mode == 0) { // copy from device to host
-            cudaMemcpyAsync(cache_cpu_ptr_now, cache_gpu_ptr_now, cache_stride * sizeof(DataType_) * consecutive_block_count, cudaMemcpyDeviceToHost, stream);
-        } else { // copy from host to device
-            cudaMemcpyAsync(cache_gpu_ptr_now, cache_cpu_ptr_now, cache_stride * sizeof(DataType_) * consecutive_block_count, cudaMemcpyHostToDevice, stream);
+    auto stream = cache_gpu.stream();
+    if (swap_block_ids_gpu.size() == 0) {
+      return;
+    }
+    int i = 0;
+    int64_t consecutive_block_count = 1;
+    int64_t last_gpu_block_id = swap_block_ids_gpu[i];
+    int64_t last_cpu_block_id = swap_block_ids_cpu[i];
+    int64_t first_gpu_block_id =
+        last_gpu_block_id;  // first block id in a consecutive block ids
+    int64_t first_cpu_block_id = last_cpu_block_id;
+    i += 1;
+    while (true) {
+      if (i >= swap_block_ids_gpu.size()) {
+        break;
+      }
+      int64_t gpu_block_id = swap_block_ids_gpu[i];
+      int64_t cpu_block_id = swap_block_ids_cpu[i];
+      assert(gpu_block_id >= 0 && gpu_block_id < max_block_num_gpu);
+      assert(cpu_block_id >= 0 && cpu_block_id < max_block_num_cpu);
+      if (gpu_block_id == last_gpu_block_id + 1 &&
+          cpu_block_id == last_cpu_block_id + 1) {  // consecutive
+        consecutive_block_count += 1;
+        last_gpu_block_id = gpu_block_id;
+        last_cpu_block_id = cpu_block_id;
+      } else {
+        // end of a consecutive block ids
+        auto* cache_gpu_ptr_now =
+            cache_gpu_ptr + first_gpu_block_id * cache_stride;
+        auto* cache_cpu_ptr_now =
+            cache_cpu_ptr + first_cpu_block_id * cache_stride;
+        if (mode == 0) {  // copy from device to host
+          cudaMemcpyAsync(
+              cache_cpu_ptr_now,
+              cache_gpu_ptr_now,
+              cache_stride * sizeof(DataType_) * consecutive_block_count,
+              cudaMemcpyDeviceToHost,
+              stream);
+        } else {  // copy from host to device
+          cudaMemcpyAsync(
+              cache_gpu_ptr_now,
+              cache_cpu_ptr_now,
+              cache_stride * sizeof(DataType_) * consecutive_block_count,
+              cudaMemcpyHostToDevice,
+              stream);
         }
+        first_gpu_block_id = gpu_block_id;
+        first_cpu_block_id = cpu_block_id;
+        last_gpu_block_id = gpu_block_id;
+        last_cpu_block_id = cpu_block_id;
+        consecutive_block_count = 1;
+      }
+      i += 1;
+    }
+    // last batch
+    auto* cache_gpu_ptr_now = cache_gpu_ptr + first_gpu_block_id * cache_stride;
+    auto* cache_cpu_ptr_now = cache_cpu_ptr + first_cpu_block_id * cache_stride;
+    if (mode == 0) {  // copy from device to host
+      cudaMemcpyAsync(
+          cache_cpu_ptr_now,
+          cache_gpu_ptr_now,
+          cache_stride * sizeof(DataType_) * consecutive_block_count,
+          cudaMemcpyDeviceToHost,
+          stream);
+    } else {  // copy from host to device
+      cudaMemcpyAsync(
+          cache_gpu_ptr_now,
+          cache_cpu_ptr_now,
+          cache_stride * sizeof(DataType_) * consecutive_block_count,
+          cudaMemcpyHostToDevice,
+          stream);
     }
-        cudaStreamSynchronize(stream);
+  }
+  cudaStreamSynchronize(stream);
 }
 
-void SwapCacheAllLayers(const std::vector<paddle::Tensor>& cache_gpu_tensors, // gpu
-               const std::vector<int64_t>& cache_cpu_ptrs, // cpu memory pointer
-               int64_t max_block_num_cpu, // cpu max block num
-               const std::vector<int64_t>& swap_block_ids_gpu,
-               const std::vector<int64_t>& swap_block_ids_cpu,
-               int rank,
-               int mode) {
-    cudaSetDevice(rank); // used for distributed launch
-    assert(cache_gpu_tensors.size() > 0 && cache_gpu_tensors.size() == cache_cpu_ptrs.size());
-    switch (cache_gpu_tensors[0].dtype()) {
-        case paddle::DataType::BFLOAT16:
-            return SwapCacheImplAllLayers<paddle::DataType::BFLOAT16>(
-                cache_gpu_tensors,
-                cache_cpu_ptrs,
-                max_block_num_cpu,
-                swap_block_ids_gpu,
-                swap_block_ids_cpu,
-                mode);
-        case paddle::DataType::FLOAT16:
-            return SwapCacheImplAllLayers<paddle::DataType::FLOAT16>(
-                cache_gpu_tensors,
-                cache_cpu_ptrs,
-                max_block_num_cpu,
-                swap_block_ids_gpu,
-                swap_block_ids_cpu,
-                mode);
-        case paddle::DataType::UINT8:
-            return SwapCacheImplAllLayers<paddle::DataType::UINT8>(
-                cache_gpu_tensors,
-                cache_cpu_ptrs,
-                max_block_num_cpu,
-                swap_block_ids_gpu,
-                swap_block_ids_cpu,
-                mode);
-        default:
-            PD_THROW("Unsupported data type.");
-    }
+void SwapCacheAllLayers(
+    const std::vector<paddle::Tensor>& cache_gpu_tensors,  // gpu
+    const std::vector<int64_t>& cache_cpu_ptrs,  // cpu memory pointer
+    int64_t max_block_num_cpu,                   // cpu max block num
+    const std::vector<int64_t>& swap_block_ids_gpu,
+    const std::vector<int64_t>& swap_block_ids_cpu,
+    int rank,
+    int mode) {
+  cudaSetDevice(rank);  // used for distributed launch
+  assert(cache_gpu_tensors.size() > 0 &&
+         cache_gpu_tensors.size() == cache_cpu_ptrs.size());
+  switch (cache_gpu_tensors[0].dtype()) {
+    case paddle::DataType::BFLOAT16:
+      return SwapCacheImplAllLayers<paddle::DataType::BFLOAT16>(
+          cache_gpu_tensors,
+          cache_cpu_ptrs,
+          max_block_num_cpu,
+          swap_block_ids_gpu,
+          swap_block_ids_cpu,
+          mode);
+    case paddle::DataType::FLOAT16:
+      return SwapCacheImplAllLayers<paddle::DataType::FLOAT16>(
+          cache_gpu_tensors,
+          cache_cpu_ptrs,
+          max_block_num_cpu,
+          swap_block_ids_gpu,
+          swap_block_ids_cpu,
+          mode);
+    case paddle::DataType::UINT8:
+      return SwapCacheImplAllLayers<paddle::DataType::UINT8>(cache_gpu_tensors,
+                                                             cache_cpu_ptrs,
+                                                             max_block_num_cpu,
+                                                             swap_block_ids_gpu,
+                                                             swap_block_ids_cpu,
+                                                             mode);
+    default:
+      PD_THROW("Unsupported data type.");
+  }
 }
 
 PD_BUILD_STATIC_OP(swap_cache_all_layers)
     .Inputs({paddle::Vec("cache_gpu_tensors")})
-    .Attrs({"cache_cpu_ptrs: std::vector<int64_t>",
-            "max_block_num_cpu: int64_t",
-            "swap_block_ids_gpu: std::vector<int64_t>",
-            "swap_block_ids_cpu: std::vector<int64_t>",
-            "rank: int",
-            "mode: int",})
+    .Attrs({
+        "cache_cpu_ptrs: std::vector<int64_t>",
+        "max_block_num_cpu: int64_t",
+        "swap_block_ids_gpu: std::vector<int64_t>",
+        "swap_block_ids_cpu: std::vector<int64_t>",
+        "rank: int",
+        "mode: int",
+    })
     .Outputs({paddle::Vec("cache_dst_outs")})
-    .SetInplaceMap({{paddle::Vec("cache_gpu_tensors"), paddle::Vec("cache_dst_outs")}})
+    .SetInplaceMap({{paddle::Vec("cache_gpu_tensors"),
+                     paddle::Vec("cache_dst_outs")}})
     .SetKernelFn(PD_KERNEL(SwapCacheAllLayers));
diff --git a/fastdeploy/cache_manager/cache_transfer_manager.py b/fastdeploy/cache_manager/cache_transfer_manager.py
index c9b6cd83f5b..6a67013fcf7 100644
--- a/fastdeploy/cache_manager/cache_transfer_manager.py
+++ b/fastdeploy/cache_manager/cache_transfer_manager.py
@@ -63,7 +63,7 @@ def parse_args():
         "--cache_dtype",
         type=str,
         default="bfloat16",
-        choices=["uint8", "bfloat16"],
+        choices=["uint8", "bfloat16", "block_wise_fp8"],
         help="cache dtype",
     )
     parser.add_argument("--key_cache_shape", type=str, default="", help="key cache shape")
@@ -114,6 +114,8 @@ def __init__(self, args):
         self.cpu_cache_kvs = {}
         self.gpu_cache_k_tensors = []
         self.gpu_cache_v_tensors = []
+        self.gpu_cache_scales_k_tensors = []
+        self.gpu_cache_scales_v_tensors = []
         self.speculative_config = SpeculativeConfig(args.speculative_config)
         self.key_cache_shape = [int(i) for i in args.key_cache_shape.split(",")]
         self.value_cache_shape = []
@@ -131,6 +133,7 @@ def __init__(self, args):
         self.rank = rank
         self.device = device
         self.engine_pid = args.engine_pid
+        self.cache_dtype = args.cache_dtype
 
         address = (args.pod_ip, args.cache_queue_port)
         self.cache_task_queue = EngineCacheQueue(
@@ -203,12 +206,19 @@ def _init_gpu_cache(self, args):
                 time.sleep(0.1)
             logger.info(f"[rank {self.rank}/{self.n_ranks}] OK! Stop waiting.")
 
+        if args.cache_dtype == "block_wise_fp8":
+            cache_type = "uint8"
+        else:
+            cache_type = args.cache_dtype
+
         logger.info(f"[rank {self.rank}/{self.n_ranks}] Initializing kv cache for all layers.")
         set_device(self.device)
         for i in range(args.num_layers + self.num_extra_layers):
             num_gpu_blocks = self.num_gpu_blocks if i < args.num_layers else self.num_extra_layer_gpu_blocks
             key_name = f"key_caches_{i}_rank{self.rank}.device{self.device}"
             val_name = f"value_caches_{i}_rank{self.rank}.device{self.device}"
+            key_cache_scales_name = f"key_cache_scales_{i}_rank{self.rank}.device{self.device}"
+            value_cache_scales_name = f"value_cache_scales_{i}_rank{self.rank}.device{self.device}"
             key_cache_shape = [
                 num_gpu_blocks,
                 self.key_cache_shape[1],
@@ -227,26 +237,64 @@ def _init_gpu_cache(self, args):
                 logger.info(
                     f"[rank {self.rank}/{self.n_ranks}] ..creating kv cache for layer {i}: {key_cache_shape} {value_cache_shape}"
                 )
-                key_cache = paddle.full(shape=key_cache_shape, fill_value=0, dtype=args.cache_dtype)
+                key_cache = paddle.full(shape=key_cache_shape, fill_value=0, dtype=cache_type)
                 set_data_ipc(key_cache, key_name)
+
+                if args.cache_dtype == "block_wise_fp8":
+                    key_cache_scales = paddle.full(
+                        shape=[num_gpu_blocks, self.key_cache_shape[1], self.key_cache_shape[2]],
+                        fill_value=0,
+                        dtype=paddle.get_default_dtype(),
+                    )
+                    set_data_ipc(key_cache_scales, key_cache_scales_name)
                 if self.value_cache_shape:
-                    val_cache = paddle.full(shape=value_cache_shape, fill_value=0, dtype=args.cache_dtype)
+                    val_cache = paddle.full(shape=value_cache_shape, fill_value=0, dtype=cache_type)
                     set_data_ipc(val_cache, val_name)
+
+                    if args.cache_dtype == "block_wise_fp8":
+                        value_cache_scales = paddle.full(
+                            shape=[num_gpu_blocks, self.value_cache_shape[1], self.value_cache_shape[2]],
+                            fill_value=0,
+                            dtype=paddle.get_default_dtype(),
+                        )
+                        set_data_ipc(value_cache_scales, value_cache_scales_name)
             else:
                 logger.info(
                     f"[rank {self.rank}/{self.n_ranks}] ..attaching kv cache for layer {i}: {key_cache_shape} {value_cache_shape}"
                 )
-                key_cache = paddle.empty(shape=[], dtype=args.cache_dtype)
-                val_cache = paddle.empty(shape=[], dtype=args.cache_dtype)
+                key_cache = paddle.empty(shape=[], dtype=cache_type)
+                val_cache = paddle.empty(shape=[], dtype=cache_type)
                 key_cache = share_external_data_(key_cache, key_name, key_cache_shape, True)
+                if args.cache_dtype == "block_wise_fp8":
+                    key_cache_scales = paddle.empty(shape=[], dtype=paddle.get_default_dtype())
+                    key_cache_scales = share_external_data_(
+                        key_cache_scales,
+                        key_cache_scales_name,
+                        [num_gpu_blocks, self.key_cache_shape[1], self.key_cache_shape[2]],
+                        True,
+                    )
                 if self.value_cache_shape:
                     val_cache = share_external_data_(val_cache, val_name, value_cache_shape, True)
+                    if args.cache_dtype == "block_wise_fp8":
+                        value_cache_scales = paddle.empty(shape=[], dtype=paddle.get_default_dtype())
+                        value_cache_scales = share_external_data_(
+                            value_cache_scales,
+                            value_cache_scales_name,
+                            [num_gpu_blocks, self.value_cache_shape[1], self.value_cache_shape[2]],
+                            True,
+                        )
             self.gpu_cache_kvs[key_name] = key_cache
             self.gpu_cache_k_tensors.append(self.gpu_cache_kvs[key_name])
+            if args.cache_dtype == "block_wise_fp8":
+                self.gpu_cache_kvs[key_cache_scales_name] = key_cache_scales
+                self.gpu_cache_scales_k_tensors.append(self.gpu_cache_kvs[key_cache_scales_name])
             if args.value_cache_shape:
                 self.gpu_cache_kvs[val_name] = val_cache
                 self.gpu_cache_v_tensors.append(self.gpu_cache_kvs[val_name])
+                if args.cache_dtype == "block_wise_fp8":
+                    self.gpu_cache_kvs[value_cache_scales_name] = value_cache_scales
+                    self.gpu_cache_scales_v_tensors.append(self.gpu_cache_kvs[value_cache_scales_name])
 
         if args.create_cache_tensor:
             logger.info(f"[rank {self.rank}/{self.n_ranks}] ✅ kv cache is ready!")
@@ -265,12 +313,17 @@ def _init_cpu_cache(self, args):
             value_cache_size = 0
         if args.cache_dtype == "bfloat16":
             cache_bytes = 2
-        elif args.cache_dtype == "uint8":
+        elif args.cache_dtype == "uint8" or args.cache_dtype == "block_wise_fp8":
            cache_bytes = 1
         else:
             raise ValueError(f"Unsupported cache dtype: {args.cache_dtype}")
         key_need_to_allocate_bytes = args.num_cpu_blocks * cache_bytes * key_cache_size
         value_need_to_allocate_bytes = args.num_cpu_blocks * cache_bytes * value_cache_size
+        if args.cache_dtype == "block_wise_fp8":
+            cache_scales = paddle.empty(shape=[], dtype=paddle.get_default_dtype())
+            cache_scales_size = self.key_cache_shape[1] * self.key_cache_shape[2]
+            scales_key_need_to_allocate_bytes = args.num_cpu_blocks * cache_scales.element_size() * cache_scales_size
+            scales_value_need_to_allocate_bytes = args.num_cpu_blocks * cache_scales.element_size() * cache_scales_size
         logger.info(
             f"[rank {self.rank}/{self.n_ranks}] ..swap space size : {(key_need_to_allocate_bytes + value_need_to_allocate_bytes) / 1024 ** 3:.2f}GB"
         )
@@ -282,17 +335,27 @@ def _init_cpu_cache(self, args):
         paddle.set_device("cpu")
         self.k_dst_ptrs = []
         self.v_dst_ptrs = []
+        self.k_scales_ptrs = []
+        self.v_scales_ptrs = []
         for i in range(args.num_layers + self.num_extra_layers):
             key_name = f"key_caches_{i}_rank{self.rank}"
             val_name = f"value_caches_{i}_rank{self.rank}"
+            key_cache_scales_name = f"key_cache_scales_{i}_rank{self.rank}"
+            value_cache_scales_name = f"value_cache_scales_{i}_rank{self.rank}"
             logger.info(
                 f"[rank {self.rank}/{self.n_ranks}] ..creating cpu cache for layer {i}: {(key_need_to_allocate_bytes + value_need_to_allocate_bytes) / 1024 ** 3:.2f}GB"
             )
             self.cpu_cache_kvs[key_name] = cuda_host_alloc(key_need_to_allocate_bytes)
             self.k_dst_ptrs.append(self.cpu_cache_kvs[key_name])
+            if args.cache_dtype == "block_wise_fp8":
+                self.cpu_cache_kvs[key_cache_scales_name] = cuda_host_alloc(scales_key_need_to_allocate_bytes)
+                self.k_scales_ptrs.append(self.cpu_cache_kvs[key_cache_scales_name])
             if value_need_to_allocate_bytes > 0:
                 self.cpu_cache_kvs[val_name] = cuda_host_alloc(value_need_to_allocate_bytes)
                 self.v_dst_ptrs.append(self.cpu_cache_kvs[val_name])
+                if args.cache_dtype == "block_wise_fp8":
+                    self.cpu_cache_kvs[value_cache_scales_name] = cuda_host_alloc(scales_value_need_to_allocate_bytes)
+                    self.v_scales_ptrs.append(self.cpu_cache_kvs[value_cache_scales_name])
 
         logger.info(f"[rank {self.rank}/{self.n_ranks}] ✅ swap space (cpu cache) is ready!")
         self.swap_space_ready_signal.value[self.rank] = 1
@@ -492,6 +555,25 @@ def _transfer_data(
                     self.device,
                     0,
                 )
+                if self.cache_dtype == "block_wise_fp8":
+                    swap_cache_all_layers(
+                        self.gpu_cache_scales_k_tensors,
+                        self.k_scales_ptrs,
+                        self.num_cpu_blocks,
+                        gpu_block_ids,
+                        cpu_block_ids,
+                        self.device,
+                        0,
+                    )
+                    swap_cache_all_layers(
+                        self.gpu_cache_scales_v_tensors,
+                        self.v_scales_ptrs,
+                        self.num_cpu_blocks,
+                        gpu_block_ids,
+                        cpu_block_ids,
+                        self.device,
+                        0,
+                    )
 
             elif event_type.value == CacheStatus.SWAP2GPU.value:
                 swap_cache_all_layers(
@@ -512,6 +594,25 @@ def _transfer_data(
                     self.device,
                     1,
                 )
+                if self.cache_dtype == "block_wise_fp8":
+                    swap_cache_all_layers(
+                        self.gpu_cache_scales_k_tensors,
+                        self.k_scales_ptrs,
+                        self.num_cpu_blocks,
+                        gpu_block_ids,
+                        cpu_block_ids,
+                        self.device,
+                        1,
+                    )
+                    swap_cache_all_layers(
+                        self.gpu_cache_scales_v_tensors,
+                        self.v_scales_ptrs,
+                        self.num_cpu_blocks,
+                        gpu_block_ids,
+                        cpu_block_ids,
+                        self.device,
+                        1,
+                    )
             else:
                 logger.warning(
                     f"transfer data: Get unexpected event type {event_type}, only SWAP2CPU and SWAP2GPU supported"
diff --git a/fastdeploy/config.py b/fastdeploy/config.py
index 5faad79f107..91d2bae64a5 100644
--- a/fastdeploy/config.py
+++ b/fastdeploy/config.py
@@ -1239,6 +1239,8 @@ def __init__(self, args):
             self.enable_hierarchical_cache = True
 
         if self.model_cfg is not None:
+            if self.model_cfg.quantization is not None and isinstance(self.model_cfg.quantization, dict):
+                self.cache_dtype = self.model_cfg.quantization.get("kv_cache_quant_type", self.cache_dtype)
             if self.model_cfg.quantization_config is not None:
                 self.cache_dtype = self.model_cfg.quantization_config.get("kv_cache_quant_type", self.cache_dtype)
             if (
diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py
index edbf67d4bda..a63876d4afb 100644
--- a/fastdeploy/worker/gpu_model_runner.py
+++ b/fastdeploy/worker/gpu_model_runner.py
@@ -1450,8 +1450,10 @@ def initialize_kv_cache(self, profile: bool = False) -> None:
         for i in range(self.model_config.num_hidden_layers):
             # init key cache
             key_cache_name = f"key_caches_{i}_rank{local_rank}.device{self.device_id}"
+            key_cache_scales_name = f"key_cache_scales_{i}_rank{local_rank}.device{self.device}"
             if value_cache_shape:
                 val_cache_name = f"value_caches_{i}_rank{local_rank}.device{self.device_id}"
+                value_cache_scales_name = f"value_cache_scales_{i}_rank{local_rank}.device{self.device}"
             if create_cache_tensor:
                 logger.info(f"..creating kv cache for layer {i}: key:{key_cache_shape}, value:{value_cache_shape}")
                 key_cache = paddle.full(shape=key_cache_shape, fill_value=0, dtype=cache_type)
@@ -1477,12 +1479,25 @@ def initialize_kv_cache(self, profile: bool = False) -> None:
                 logger.info(f"..attaching kv cache for layer {i}: key:{key_cache_shape}, value:{value_cache_shape}")
                 key_cache = paddle.empty(shape=[], dtype=cache_type)
                 key_cache = share_external_data(key_cache, key_cache_name, key_cache_shape)
+                if kv_cache_quant_type == "block_wise_fp8":
+                    key_cache_scales = paddle.empty(shape=[], dtype=paddle.get_default_dtype())
+                    key_cache_scales = share_external_data(
+                        key_cache_scales, key_cache_scales_name, kv_cache_scale_shape
+                    )
                 if value_cache_shape:
                     val_cache = paddle.empty(shape=[], dtype=cache_type)
                     val_cache = share_external_data(val_cache, val_cache_name, value_cache_shape)
                     cache_kvs_list.extend([key_cache, val_cache])
+                    if kv_cache_quant_type == "block_wise_fp8":
+                        val_cache_scales = paddle.empty(shape=[], dtype=paddle.get_default_dtype())
+                        val_cache_scales = share_external_data(
+                            val_cache_scales, value_cache_scales_name, kv_cache_scale_shape
+                        )
+                        cache_kvs_list.extend([key_cache_scales, val_cache_scales])
                 else:
                     cache_kvs_list.extend([key_cache])
+                    if kv_cache_quant_type == "block_wise_fp8":
+                        cache_kvs_list.extend([key_cache_scales])
 
         self.share_inputs["caches"] = cache_kvs_list
diff --git a/tests/cache_manager/test_cache_transfer_manager.py b/tests/cache_manager/test_cache_transfer_manager.py
index 96f0b2ada26..f09fc603325 100644
--- a/tests/cache_manager/test_cache_transfer_manager.py
+++ b/tests/cache_manager/test_cache_transfer_manager.py
@@ -25,6 +25,7 @@ class Args:
     key_cache_shape = "1,1,1,1"
     value_cache_shape = ""
     create_cache_tensor = False
+    cache_dtype = "bfloat16"
 
 
 # ==========================
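
A rough sketch (not part of the patch) of the swap-space sizing the `block_wise_fp8` path implies: the quantized KV payload is stored as uint8 (1 byte per element) and each cache gets a companion scale tensor sized `key_cache_shape[1] * key_cache_shape[2]` elements per block in the default dtype, which is what `_init_cpu_cache` allocates via `cuda_host_alloc`. The layout `[num_blocks, num_kv_heads, block_size, head_dim]` and the 2-byte default dtype below are assumptions for illustration.

```python
# Sketch only: mirrors the byte arithmetic in _init_cpu_cache for block_wise_fp8.
# Assumes cache layout [num_blocks, num_kv_heads, block_size, head_dim] and a
# 2-byte default dtype (e.g. bfloat16) for the scales.

def swap_space_bytes_per_layer(num_cpu_blocks, num_kv_heads, block_size, head_dim, scale_elem_bytes=2):
    cache_elems_per_block = num_kv_heads * block_size * head_dim
    cache_bytes = num_cpu_blocks * 1 * cache_elems_per_block   # fp8 payload: 1 byte per element
    scale_elems_per_block = num_kv_heads * block_size          # key_cache_shape[1] * key_cache_shape[2]
    scale_bytes = num_cpu_blocks * scale_elem_bytes * scale_elems_per_block
    return cache_bytes, scale_bytes

# e.g. 64 CPU blocks, 8 KV heads, block_size 64, head_dim 128:
# (4194304, 65536) -> about 4 MiB of fp8 cache plus 64 KiB of scales per K (and per V) per layer,
# which is why _transfer_data issues a second pair of swap_cache_all_layers calls for the scales.
print(swap_space_bytes_per_layer(64, 8, 64, 128))
```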