Skip to content

Commit fba4117

Browse files
authored
Fix Conv LHS packing padding/uninitialized ptrs V2 (#27215)
### Description

Refer to V1 of the fix here: #27214. This PR includes all fixes from the V1 PR, plus logic to invalidate the cached LHS pointers when the pad buffer's underlying storage has been reallocated by a resize. The ARM team will look at potentially enhancing this logic after the 1.24.0 release.

### Motivation and Context

Fixes #26669.
1 parent 711d155 commit fba4117

File tree

1 file changed

+33
-2
lines changed

1 file changed

+33
-2
lines changed

onnxruntime/core/mlas/lib/kleidiai/convolve_kleidiai.cpp

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -395,6 +395,12 @@ static std::shared_ptr<const void*[]> LhsPtrFill(const size_t ci, const size_t i
395395
auto lhs_ptrs = std::shared_ptr<const void*[]>(new const void*[lhs_ptrs_k * lhs_ptrs_m],
396396
std::default_delete<const void*[]>());
397397

398+
// Initialize all padding entries. For partial tiles (m < m_step),
399+
// the kai LHS packing kernel may still read pointer entries beyond the logically
400+
// filled 'm' positions. Leaving these uninitialized can cause non-deterministic
401+
// reads and corrupt packed LHS data.
402+
auto lhs_ptrs_ = lhs_ptrs.get();
403+
std::fill(lhs_ptrs_, lhs_ptrs_ + (lhs_ptrs_k * lhs_ptrs_m), reinterpret_cast<const void*>(&pad_ptr[0]));
398404

399405
auto ih_out_size = ComputeConvOutSize(ih, kh, padding, 1);
400406
auto iw_out_size = ComputeConvOutSize(iw, kw, padding, 1);
@@ -430,7 +436,6 @@ static std::shared_ptr<const void*[]> LhsPtrFill(const size_t ci, const size_t i
430436
};
431437

432438
size_t m_{0};
433-
auto lhs_ptrs_ = lhs_ptrs.get();
434439
for (size_t ih_ = 0; ih_ < ih_out_size; ih_ += sh) {
435440
for (size_t iw_ = 0; iw_ < iw_out_size; iw_ += sw, ++m_) {
436441
size_t k_{0};
@@ -460,7 +465,23 @@ static std::unique_ptr<std::byte[]> LhsPackImageDataSme(const size_t ci, const s
460465
// figure out how many blocks needed to correctly fill padding
461466
padsize = ((ci + padsize - 1) / padsize) * padsize;
462467
}
463-
static std::vector<float>pad_ptr(padsize, 0.f);
468+
469+
// pad_ptr must be at least 'ci' floats for padding pixels.
470+
// Using a thread_local grow-only buffer to avoid cross-thread interference and ensure sizing is correct.
471+
thread_local std::vector<float> pad_ptr;
472+
const float* old_pad_ptr = pad_ptr.data();
473+
bool has_pad_ptr_changed = false;
474+
475+
if (pad_ptr.size() < padsize) {
476+
pad_ptr.resize(padsize, 0.f);
477+
if (pad_ptr.data() != old_pad_ptr) {
478+
has_pad_ptr_changed = true;
479+
}
480+
} else {
481+
// Ensure any previously-used region remains zeroed (grow-only means it should already be zeros,
482+
// but keep this explicit for safety).
483+
std::fill(pad_ptr.begin(), pad_ptr.end(), 0.f);
484+
}
464485

465486
LhsCacheKey key = {
466487
ci, ih, iw,
@@ -481,6 +502,16 @@ static std::unique_ptr<std::byte[]> LhsPackImageDataSme(const size_t ci, const s
481502
// Cache of computed lhs ptr offsets. thread_local to prevent interference from parallel sessions.
482503
thread_local std::unordered_map<LhsCacheKey, std::shared_ptr<const void*[]>> lhs_ptrs_cache;
483504

505+
if (has_pad_ptr_changed)
506+
{
507+
// If the pad buffer was resized and a re-allocation has occurred, the cached lhs ptrs are invalid as they
508+
// would be referencing the old pad buffer.
509+
// See discussion in https://github.com/microsoft/onnxruntime/pull/27214.
510+
// TODO(hasesh / JonathanC-ARM): A better approach would be to include the pad buffer address in the cache key
511+
// or any other approach that would reduce unnecessary cache invalidations.
512+
lhs_ptrs_cache.clear();
513+
}
514+
484515
std::shared_ptr<const void*[]> lhs_ptrs;
485516
if (auto found = lhs_ptrs_cache.find(key); found != lhs_ptrs_cache.end()) {
486517
lhs_ptrs = found->second;

0 commit comments

Comments
 (0)