@@ -17,7 +17,6 @@
 
 #include <algorithm>
 #include <cassert>
-#include <cmath>
 #include <cfloat>
 #include <cstring>
 #include <cmath>
@@ -440,7 +439,7 @@ struct llama_model::impl {
     llama_mlocks mlock_mmaps;
 
     // contexts where the model tensors metadata is stored as well as the corresponding buffers:
-    std::vector<std::pair<ggml_context_ptr, ggml_backend_buffer_ptr>> ctxs_bufs;
+    std::vector<std::pair<ggml_context_ptr, std::vector<ggml_backend_buffer_ptr>>> ctxs_bufs;
 
     buft_list_t cpu_buft_list;
     std::map<ggml_backend_dev_t, buft_list_t> gpu_buft_list;
@@ -6188,7 +6187,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
         bool buffer_from_host_ptr_supported = props.caps.buffer_from_host_ptr;
         bool is_default_buft = buft == ggml_backend_dev_buffer_type(dev);
 
-        ggml_backend_buffer_t buf = nullptr;
+        std::vector<ggml_backend_buffer_ptr> bufs;
         if (ml.use_mmap && use_mmap_buffer && buffer_from_host_ptr_supported && is_default_buft) {
             for (uint32_t idx = 0; idx < ml.files.size(); idx++) {
                 // only the mmap region containing the tensors in the model is mapped to the backend buffer
@@ -6201,15 +6200,16 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                     continue;
                 }
                 const size_t max_size = ggml_get_max_tensor_size(ctx);
-                buf = ggml_backend_dev_buffer_from_host_ptr(dev, (char *) addr + first, last - first, max_size);
+                ggml_backend_buffer_t buf = ggml_backend_dev_buffer_from_host_ptr(dev, (char *) addr + first, last - first, max_size);
                 if (buf == nullptr) {
                     throw std::runtime_error(format("unable to allocate %s buffer", ggml_backend_buft_name(buft)));
                 }
+                bufs.emplace_back(buf);
                 buf_map.emplace(idx, buf);
             }
         }
         else {
-            buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
+            ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
             if (buf == nullptr) {
                 throw std::runtime_error(format("unable to allocate %s buffer", ggml_backend_buft_name(buft)));
             }
@@ -6219,11 +6219,12 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                 mlock_buf->init(ggml_backend_buffer_get_base(buf));
                 mlock_buf->grow_to(ggml_backend_buffer_get_size(buf));
             }
+            bufs.emplace_back(buf);
             for (uint32_t idx = 0; idx < ml.files.size(); idx++) {
                 buf_map.emplace(idx, buf);
             }
         }
-        pimpl->ctxs_bufs.emplace_back(std::move(ctx_ptr), buf);
+        pimpl->ctxs_bufs.emplace_back(std::move(ctx_ptr), std::move(bufs));
 
         for (auto & buf : buf_map) {
             // indicate that this buffer contains weights
@@ -6249,8 +6250,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
     }
 
     // print memory requirements per buffer type
-    for (auto & [_, buf] : pimpl->ctxs_bufs) {
-        LLAMA_LOG_INFO("%s: %12s model buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf.get()), ggml_backend_buffer_get_size(buf.get()) / 1024.0 / 1024.0);
+    for (auto & [_, bufs] : pimpl->ctxs_bufs) {
+        for (auto & buf : bufs) {
+            LLAMA_LOG_INFO("%s: %12s model buffer size = %8.2f MiB\n",
+                __func__, ggml_backend_buffer_name(buf.get()), ggml_backend_buffer_get_size(buf.get()) / 1024.0 / 1024.0);
+        }
     }
 
     // populate tensors_by_name
@@ -6302,8 +6306,10 @@ size_t llama_model::n_devices() const {
 
 std::map<ggml_backend_buffer_type_t, size_t> llama_model::memory_breakdown() const {
     std::map<ggml_backend_buffer_type_t, size_t> ret;
-    for (const auto & [_, buf] : pimpl->ctxs_bufs) {
-        ret[ggml_backend_buffer_get_type(buf.get())] += ggml_backend_buffer_get_size(buf.get());
+    for (const auto & [_, bufs] : pimpl->ctxs_bufs) {
+        for (const auto & buf : bufs) {
+            ret[ggml_backend_buffer_get_type(buf.get())] += ggml_backend_buffer_get_size(buf.get());
+        }
     }
     return ret;
 }
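Taken together, the change turns the one-to-one pairing of a ggml context with a single backend buffer into a one-to-many pairing: with mmap loading, each model file region can end up in its own host-pointer buffer, so every context now carries a vector of buffers and the consumers iterate a nested structure. Below is a minimal, self-contained sketch of that iteration pattern using plain standard-library stand-ins; the buffer struct and the string keys are hypothetical placeholders, not the real ggml types, and it only mirrors what the updated memory_breakdown() and the per-buffer logging loop do.

#include <cstddef>
#include <cstdio>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for ggml_backend_buffer_ptr: just a buffer-type tag and a size.
struct buffer {
    std::string type; // stand-in for ggml_backend_buffer_get_type()
    std::size_t size; // stand-in for ggml_backend_buffer_get_size()
};

int main() {
    // One context may now own several buffers, e.g. one per mmap'd model file.
    std::vector<std::pair<std::string, std::vector<buffer>>> ctxs_bufs = {
        {"ctx0", {{"CUDA0", 4096}, {"CUDA0", 8192}}},
        {"ctx1", {{"CPU",   1024}}},
    };

    // Aggregate sizes per buffer type, like the updated llama_model::memory_breakdown().
    std::map<std::string, std::size_t> breakdown;
    for (const auto & [ctx, bufs] : ctxs_bufs) {
        for (const auto & buf : bufs) {
            breakdown[buf.type] += buf.size;
        }
    }

    for (const auto & [type, size] : breakdown) {
        std::printf("%s: %zu bytes\n", type.c_str(), size);
    }
    return 0;
}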