Skip to content

Commit f83d4d0

Browse files
quic-calvnguyquic_calvnguy
andauthored
[QNN-EP] Implement file mapped weights feature (#26952)
Description: Enables file mapping of the weights as well as of the overall context bin. This feature is currently only enabled for ARM64 Windows devices. Motivation and Context: Currently, when reading the context bin, ORT allocates a large buffer on the heap. Even when the same model is used, each ORT session allocates its own buffer for the context bin, which is very wasteful for large models. Instead, Windows file mapping can be leveraged to map the context bin; every time a context needs to be created from the context bin, the pointer to the mapped context bin can be retrieved and used instead of a pre-allocated buffer, making QNN EP more memory-efficient. With multiple ORT sessions, the context bin is loaded only once for all sessions, improving memory efficiency and overall initialization performance. This is especially useful for LLM workloads going forward. --------- Co-authored-by: quic_calvnguy <quic_calvnguy@quic_inc.com>
1 parent 2ece1c1 commit f83d4d0

14 files changed

+801
-52
lines changed

onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ Status GetEpContextFromMainNode(const onnxruntime::Node& main_context_node,
9494
const std::string& context_binary = node_helper.Get(EP_CACHE_CONTEXT, "");
9595
return qnn_backend_manager->LoadCachedQnnContextFromBuffer(const_cast<char*>(context_binary.c_str()),
9696
static_cast<uint64_t>(context_binary.length()),
97+
"",
9798
main_context_node.Name(),
9899
qnn_models,
99100
max_spill_fill_size);
@@ -127,6 +128,18 @@ Status GetEpContextFromMainNode(const onnxruntime::Node& main_context_node,
127128
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, "The file path in ep_cache_context does not exist or is not accessible.");
128129
}
129130

131+
std::string context_binary_path_str = context_binary_path.string();
132+
#ifdef QNN_FILE_MAPPED_WEIGHTS_AVAILABLE
133+
if (qnn_backend_manager->FileMappingIsEnabled()) {
134+
return qnn_backend_manager->LoadCachedQnnContextFromBuffer(nullptr,
135+
0,
136+
context_binary_path_str,
137+
main_context_node.Name(),
138+
qnn_models,
139+
max_spill_fill_size);
140+
}
141+
#endif
142+
130143
size_t buffer_size{0};
131144
std::ifstream cache_file(context_binary_path.string().c_str(), std::ifstream::binary);
132145
ORT_RETURN_IF(!cache_file || !cache_file.good(), "Failed to open cache file.");
@@ -144,6 +157,7 @@ Status GetEpContextFromMainNode(const onnxruntime::Node& main_context_node,
144157
cache_file.close();
145158
return qnn_backend_manager->LoadCachedQnnContextFromBuffer(buffer.get(),
146159
static_cast<uint64_t>(buffer_size),
160+
context_binary_path_str,
147161
main_context_node.Name(),
148162
qnn_models,
149163
max_spill_fill_size);

onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc

Lines changed: 356 additions & 40 deletions
Large diffs are not rendered by default.

onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,18 @@
2525
#include "System/QnnSystemInterface.h"
2626

2727
#include "core/providers/qnn/ort_api.h"
28+
#include "core/providers/qnn/rpcmem_library.h"
2829
#include "core/providers/qnn/builder/op_builder_factory.h"
2930
#include "core/providers/qnn/builder/qnn_context_mem_handle_manager.h"
3031
#include "core/providers/qnn/builder/qnn_def.h"
3132
#include "core/providers/qnn/builder/qnn_htp_power_config_manager.h"
3233
#include "core/providers/qnn/builder/qnn_profile_serializer.h"
3334
#include "core/providers/qnn/builder/qnn_node_group/qnn_node_group.h"
3435

36+
#ifdef QNN_FILE_MAPPED_WEIGHTS_AVAILABLE
37+
#include "core/providers/qnn/builder/qnn_file_mapping_interface.h"
38+
#endif
39+
3540
namespace onnxruntime {
3641
namespace qnn {
3742

@@ -154,6 +159,7 @@ class QnnBackendManager : public std::enable_shared_from_this<QnnBackendManager>
154159
std::unique_ptr<unsigned char[]> GetContextBinaryBuffer(uint64_t& written_buffer_size);
155160

156161
Status LoadCachedQnnContextFromBuffer(char* buffer, uint64_t buffer_length,
162+
const std::string& context_bin_filepath,
157163
std::string node_name,
158164
std::unordered_map<std::string, std::unique_ptr<qnn::QnnModel>>& qnn_models,
159165
int64_t max_spill_fill_size);
@@ -163,6 +169,8 @@ class QnnBackendManager : public std::enable_shared_from_this<QnnBackendManager>
163169
Status SetupBackend(const logging::Logger& logger, bool load_from_cached_context,
164170
bool need_load_system_lib, bool share_ep_contexts,
165171
bool enable_vtcm_backup_buffer_sharing,
172+
bool enable_file_mapped_weights,
173+
std::shared_ptr<qnn::RpcMemLibrary> rpcmem_library,
166174
std::unordered_map<std::string, std::unique_ptr<std::vector<std::string>>>& context_bin_map);
167175

168176
Status CreateHtpPowerCfgId(uint32_t deviceId, uint32_t coreId, uint32_t& htp_power_config_id);
@@ -248,9 +256,34 @@ class QnnBackendManager : public std::enable_shared_from_this<QnnBackendManager>
248256
bool ProfilingEnabled() { return profiling_enabled_; }
249257
#endif
250258

259+
bool FileMappingIsEnabled() {
260+
return file_mapped_weights_enabled_;
261+
}
262+
263+
#ifdef QNN_FILE_MAPPED_WEIGHTS_AVAILABLE
264+
Qnn_ErrorHandle_t MapDmaData(Qnn_ContextBinaryDataRequest_t request,
265+
Qnn_ContextBinaryDmaDataResponse_t* response,
266+
void* const mapped_base_ptr,
267+
const size_t file_size);
268+
269+
Qnn_ErrorHandle_t ReleaseDmaData(Qnn_ContextBinaryDmaDataMem_t data_mem, void* mapped_base_ptr);
270+
#endif
271+
251272
QnnLog_Level_t MapOrtSeverityToQNNLogLevel(logging::Severity ort_log_level);
252273
static logging::Severity MapQNNLogLevelToOrtSeverity(QnnLog_Level_t qnn_log_level);
253274

275+
#ifdef QNN_FILE_MAPPED_WEIGHTS_AVAILABLE
276+
typedef struct FileMappingCallbackInfo {
277+
void* const mapped_file_ptr;
278+
const size_t file_size;
279+
QnnBackendManager* const backend_manager;
280+
281+
FileMappingCallbackInfo(void* ptr, size_t size, QnnBackendManager* manager)
282+
: mapped_file_ptr(ptr), file_size(size), backend_manager(manager) {}
283+
284+
} FileMappingCallbackInfo_t;
285+
#endif
286+
254287
private:
255288
Status LoadBackend();
256289

@@ -268,9 +301,24 @@ class QnnBackendManager : public std::enable_shared_from_this<QnnBackendManager>
268301

269302
Status CreateContext(bool enable_htp_weight_sharing);
270303

304+
Status GetFileSizeIfValid(const std::string& filepath, size_t& file_size);
305+
306+
Status ReadContextBinIfValid(const std::string& context_bin_filepath,
307+
std::vector<char>& buffer);
308+
271309
Status CreateContextVtcmBackupBufferSharingEnabled(std::unordered_map<std::string,
272310
std::unique_ptr<std::vector<std::string>>>& context_bin_map);
273311

312+
Status CreateContextFromListAsync(const QnnContext_Config_t** configs,
313+
std::unordered_map<std::string,
314+
std::unique_ptr<std::vector<std::string>>>& context_bin_map);
315+
316+
#ifdef QNN_FILE_MAPPED_WEIGHTS_AVAILABLE
317+
Status CreateContextFromListAsyncWithCallback(const QnnContext_Config_t** configs,
318+
std::unordered_map<std::string,
319+
std::unique_ptr<std::vector<std::string>>>& context_bin_map);
320+
#endif
321+
274322
Status ReleaseContext();
275323

276324
// Sets the ORT logger and creates a corresponding QNN logger with the same log level.
@@ -455,6 +503,15 @@ class QnnBackendManager : public std::enable_shared_from_this<QnnBackendManager>
455503
bool context_created_ = false;
456504
bool backend_setup_completed_ = false;
457505
bool vtcm_backup_buffer_sharing_enabled_ = false;
506+
bool file_mapped_weights_enabled_ = false;
507+
508+
#ifdef QNN_FILE_MAPPED_WEIGHTS_AVAILABLE
509+
std::unique_ptr<FileMappingInterface> file_mapper_ = nullptr;
510+
// Notify params for file mapping must persist throughout lifetime of
511+
// QnnBackendManager for release of DMA data callback on destruction
512+
std::vector<std::unique_ptr<FileMappingCallbackInfo_t>> file_mapping_notify_params_;
513+
#endif
514+
458515
// NPU backend requires quantized model
459516
QnnBackendType qnn_backend_type_ = QnnBackendType::CPU;
460517
Qnn_ProfileHandle_t profile_backend_handle_ = nullptr;
@@ -473,6 +530,8 @@ class QnnBackendManager : public std::enable_shared_from_this<QnnBackendManager>
473530
// Mapping of thread id to on-run-start/end power configs
474531
std::mutex per_thread_power_configs_mutex_;
475532
std::unordered_map<std::thread::id, PerThreadHtpPowerConfigs_t> per_thread_power_configs_;
533+
534+
std::shared_ptr<qnn::RpcMemLibrary> rpcmem_library_ = nullptr;
476535
};
477536

478537
} // namespace qnn

onnxruntime/core/providers/qnn/builder/qnn_def.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,12 @@ namespace qnn {
1919
#define QNN_SYSTEM_PROFILE_API_ENABLED
2020
#endif
2121

22+
// QNN_FILE_MAPPED_WEIGHTS_AVAILABLE is defined only when building for Windows
// on ARM64 (_WIN32 plus __aarch64__ or _M_ARM64) against a QNN API whose
// version is 2.32 or newer. Code for the file-mapped weights feature is
// compiled only when this macro is defined.
#if defined(_WIN32) && (defined(__aarch64__) || defined(_M_ARM64))
23+
#if QNN_API_VERSION_MAJOR > 2 || ((QNN_API_VERSION_MAJOR) == 2 && (QNN_API_VERSION_MINOR >= 32))
24+
#define QNN_FILE_MAPPED_WEIGHTS_AVAILABLE
25+
#endif
26+
#endif
27+
2228
// QNN only support subset of POSIX of dlopen/dlsym/dladdr/dlerror/dlclose
2329
// except the following flags for dlopen, others should be done only
2430
// when we really need them
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#pragma once
5+
6+
#include <string>
7+
8+
#include <QnnContext.h>
9+
10+
#include "core/providers/qnn/ort_api.h"
11+
#include "core/providers/qnn/builder/qnn_def.h"
12+
13+
namespace onnxruntime {
14+
namespace qnn {
15+
16+
// Abstract interface for obtaining a memory-mapped view of a context binary
// file. WindowsFileMapper is the implementation that accompanies this header.
class FileMappingInterface {
17+
public:
18+
// Virtual so that implementations can be destroyed through this interface.
virtual ~FileMappingInterface() = default;
19+
20+
// Obtains a pointer to the mapped contents of the context binary located at
// `bin_filepath` and stores it in `*mapped_data_ptr`.
// Returns a Status describing success or failure.
// NOTE(review): who owns the mapping and how long the returned pointer stays
// valid is decided by the implementation -- confirm against the concrete class.
virtual Status GetContextBinMappedMemoryPtr(const std::string& bin_filepath,
21+
void** mapped_data_ptr) = 0;
22+
};
23+
24+
} // namespace qnn
25+
} // namespace onnxruntime
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#include "core/providers/qnn/builder/qnn_windows_file_mapper.h"
5+
#ifdef QNN_FILE_MAPPED_WEIGHTS_AVAILABLE
6+
7+
#include <wil/filesystem.h>
8+
9+
#include <utility>
10+
11+
#include <QnnContext.h>
12+
13+
#include "core/providers/qnn/ort_api.h"
14+
15+
namespace onnxruntime {
16+
namespace qnn {
17+
18+
WindowsFileMapper::WindowsFileMapper(const logging::Logger& logger)
19+
: logger_(&logger) {
20+
}
21+
22+
// No explicit teardown: the members release whatever they hold when destroyed.
WindowsFileMapper::~WindowsFileMapper() = default;
24+
25+
static void UnmapFile(void* addr) noexcept {
26+
bool successful = UnmapViewOfFile(addr);
27+
if (!successful) {
28+
const auto error_code = GetLastError();
29+
LOGS_DEFAULT(ERROR) << "Failed to unmap view of file with ptr: " << addr
30+
<< ", Error code: " << error_code << ", \""
31+
<< std::system_category().message(error_code) << "\"";
32+
}
33+
}
34+
35+
// Returns, through `*mapped_data_ptr`, the base address of a read-only
// memory-mapped view of the context binary at `bin_filepath`.
//
// Views are cached per file path in mapped_memory_ptrs_: the first request for
// a path opens the file, creates a read-only file mapping object and maps the
// whole file; subsequent requests for the same path return the cached pointer.
// The lookup and the creation of a new mapping both run under map_mutex_.
//
// Parameters:
//   bin_filepath     Path of the context binary. Must not be empty.
//   mapped_data_ptr  Output. Receives the base address of the mapped view.
//                    Must not be null.
//
// Returns Status::OK() on success. Returns a failure Status (including the
// Win32 error code and message where applicable) if an argument is invalid, or
// if opening the file, creating the mapping object, or mapping the view fails.
Status WindowsFileMapper::GetContextBinMappedMemoryPtr(const std::string& bin_filepath,
                                                       void** mapped_data_ptr) {
  LOGS(*logger_, INFO) << "Creating context bin file mapping for "
                       << bin_filepath;

  ORT_RETURN_IF(mapped_data_ptr == nullptr, "Output pointer for mapped context bin is null");
  ORT_RETURN_IF(bin_filepath.empty(), "Context bin file path is empty");

  std::lock_guard<std::mutex> lock(map_mutex_);

  // Reuse an existing view if this file has already been mapped.
  const auto map_it = mapped_memory_ptrs_.find(bin_filepath);
  if (map_it != mapped_memory_ptrs_.end()) {
    *mapped_data_ptr = map_it->second.get();
    // Log the mapped address itself, not the address of the out-parameter.
    LOGS(*logger_, INFO) << "Found existing mapview memory pointer (" << *mapped_data_ptr
                         << ") for context bin file: " << bin_filepath;
    return Status::OK();
  }

  // NOTE(review): this widens the path byte-by-byte, which is only correct for
  // ASCII paths. Paths containing non-ASCII characters need a real narrow-to-wide
  // conversion that matches the encoding the caller used -- confirm and fix.
  const std::wstring bin_filepath_wstr(bin_filepath.begin(), bin_filepath.end());

  // CreateFile2 reports failure with INVALID_HANDLE_VALUE, which is also the
  // "empty" sentinel of wil::unique_hfile.
  wil::unique_hfile file_handle{CreateFile2(bin_filepath_wstr.c_str(),
                                            GENERIC_READ,
                                            FILE_SHARE_READ,
                                            OPEN_EXISTING,
                                            nullptr)};
  if (file_handle.get() == INVALID_HANDLE_VALUE) {
    const auto error_code = GetLastError();
    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
                           "Failed to create file handle for context bin ", bin_filepath,
                           ". Error code: ", error_code, ", \"",
                           std::system_category().message(error_code), "\"");
  }

  LOGS(*logger_, VERBOSE) << "Created file handle (" << file_handle.get() << ") for context bin: "
                          << bin_filepath;

  // Unlike CreateFile2, CreateFileMappingW reports failure with NULL, so the
  // handle is held in wil::unique_handle (null sentinel) and checked against
  // nullptr rather than INVALID_HANDLE_VALUE.
  wil::unique_handle file_mapping_handle{CreateFileMappingW(file_handle.get(),
                                                            nullptr,
                                                            PAGE_READONLY,
                                                            0x00,
                                                            0x00,
                                                            nullptr)};
  if (file_mapping_handle.get() == nullptr) {
    const auto error_code = GetLastError();
    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
                           "Failed to create file mapping handle for context bin ",
                           bin_filepath, ". Error code: ", error_code, ", \"",
                           std::system_category().message(error_code), "\"");
  }

  LOGS(*logger_, VERBOSE) << "Created file mapping with handle (" << file_mapping_handle.get()
                          << ") for context bin: " << bin_filepath;

  // Offset 0 and length 0 map the entire file.
  void* const mapped_base_ptr = MapViewOfFile(file_mapping_handle.get(),
                                              FILE_MAP_READ,
                                              0, 0, 0);
  if (mapped_base_ptr == nullptr) {
    const auto error_code = GetLastError();
    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
                           "Failed to retrieve mapview pointer for context bin ",
                           bin_filepath, ". Error code: ", error_code, ", \"",
                           std::system_category().message(error_code), "\"");
  }

  LOGS(*logger_, INFO) << "Created mapview pointer with address " << mapped_base_ptr
                       << " for context bin " << bin_filepath;

  // The deleter unmaps the view when the cache entry is destroyed. The file and
  // mapping handles are closed when this function returns; per the Win32
  // documentation a mapped view stays valid until it is unmapped.
  onnxruntime::Env::MappedMemoryPtr mapped_memory_ptr{static_cast<char*>(mapped_base_ptr),
                                                      [mapped_base_ptr](void*) {
                                                        UnmapFile(mapped_base_ptr);
                                                      }};

  *mapped_data_ptr = mapped_memory_ptr.get();
  mapped_memory_ptrs_.emplace(bin_filepath, std::move(mapped_memory_ptr));

  return Status::OK();
}
110+
} // namespace qnn
111+
} // namespace onnxruntime
112+
113+
#endif // QNN_FILE_MAPPED_WEIGHTS_AVAILABLE
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#pragma once
5+
6+
#include "core/providers/qnn/builder/qnn_file_mapping_interface.h"
7+
#ifdef QNN_FILE_MAPPED_WEIGHTS_AVAILABLE
8+
9+
#include <memory>
10+
#include <mutex>
11+
#include <string>
12+
#include <unordered_map>
13+
14+
#include <QnnContext.h>
15+
16+
#include "core/providers/qnn/ort_api.h"
17+
18+
namespace onnxruntime {
19+
namespace qnn {
20+
21+
// Windows implementation of FileMappingInterface. Hands out pointers to
// memory-mapped context binaries and keeps one cached mapping per file path.
class WindowsFileMapper : public FileMappingInterface {
22+
public:
23+
// Keeps a non-owning pointer to `logger`, which is used for this object's log output.
explicit WindowsFileMapper(const logging::Logger& logger);
24+
~WindowsFileMapper() override;
25+
26+
// Creates a file mapping of the context binary and returns the
27+
// mapview pointer of the file mapping through `mapped_data_ptr`.
// A path that was mapped before yields the previously created pointer.
28+
Status GetContextBinMappedMemoryPtr(const std::string& bin_filepath,
29+
void** mapped_data_ptr) override;
30+
31+
private:
32+
// Held by GetContextBinMappedMemoryPtr while it reads or updates
33+
// mapped_memory_ptrs_.
34+
std::mutex map_mutex_;
35+
// A container of smart pointers of mapview memory pointers to mapped context bins
// key: filepath to context bin, value: smart pointer of mapview memory pointers
std::unordered_map<std::string, onnxruntime::Env::MappedMemoryPtr> mapped_memory_ptrs_;
36+
// Non-owning pointer to the logger supplied at construction.
const logging::Logger* logger_;
37+
};
38+
39+
} // namespace qnn
40+
} // namespace onnxruntime
41+
42+
#endif // QNN_FILE_MAPPED_WEIGHTS_AVAILABLE

0 commit comments

Comments
 (0)