26 changes: 14 additions & 12 deletions csrc/cpp_itfs/mha_bwd.cu
@@ -9,21 +9,23 @@ std::tuple<int, int> get_padded_hdim(int hdim_q, int hdim_v, std::string arch_id
 {
     if(hdim_q == 192 && hdim_v == 128 && arch_id == "gfx950")
         return std::make_tuple(hdim_q, hdim_v);
-    assert(hdim_q == hdim_v);
-    if(hdim_q <= 64)
Copilot AI Feb 4, 2026

This line contains only whitespace. Remove the trailing whitespace for cleaner code formatting.

+    if(hdim_q == hdim_v)
     {
-        return std::make_tuple(64, 64);
-    }
-    else if(hdim_q <= 128)
-    {
-        return std::make_tuple(128, 128);
-    }
-    else if(hdim_q <= 192)
-    {
-        return std::make_tuple(192, 192);
+        if(hdim_q <= 64)
+        {
+            return std::make_tuple(64, 64);
+        }
+        else if(hdim_q <= 128)
+        {
+            return std::make_tuple(128, 128);
+        }
+        else if(hdim_q <= 192)
+        {
+            return std::make_tuple(192, 192);
+        }
     }

-    assert(false);
     return std::make_tuple(hdim_q, hdim_v);
Copilot AI Feb 4, 2026

When hdim_q != hdim_v and the special case (hdim_q == 192 && hdim_v == 128) is not met, the function returns unpadded dimensions without any validation or error handling. This silently allows unsupported dimension combinations to proceed, which could lead to runtime errors. Consider adding validation or error handling for unsupported dimension pairs.

}
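
The review comment above asks for explicit handling of unsupported dimension pairs. One way this could look is sketched below. It is a minimal sketch, not the PR's code: `try_get_padded_hdim` is a hypothetical name, and it assumes that any pair other than the gfx950 192/128 case or equal head dims up to 192 should be rejected. The diff itself passes such pairs through unpadded and does not say whether that is intended.

```cpp
#include <initializer_list>
#include <optional>
#include <string>
#include <tuple>

// Hypothetical variant of get_padded_hdim (name and return type are
// assumptions): reports unsupported pairs as std::nullopt instead of
// returning them unpadded.
std::optional<std::tuple<int, int>>
try_get_padded_hdim(int hdim_q, int hdim_v, const std::string& arch_id)
{
    // Special case taken from the diff: 192/128 stays unpadded on gfx950.
    if(hdim_q == 192 && hdim_v == 128 && arch_id == "gfx950")
        return std::make_tuple(hdim_q, hdim_v);

    if(hdim_q == hdim_v)
    {
        // Round equal head dims up to the next bucket, as the diff does.
        for(int padded : {64, 128, 192})
        {
            if(hdim_q <= padded)
                return std::make_tuple(padded, padded);
        }
    }

    // Assumption: everything else is unsupported and left to the caller.
    return std::nullopt;
}
```

A caller would check `has_value()` before dispatching a kernel and report the rejected pair otherwise, which keeps the failure visible in release builds where `assert` is compiled out.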

4 changes: 2 additions & 2 deletions op_tests/cpp/mha/benchmark_mha_fwd.cpp
@@ -1,5 +1,5 @@
 // SPDX-License-Identifier: MIT
-// Copyright (C) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (C) 2018-2026, Advanced Micro Devices, Inc. All rights reserved.
 #include "ck_tile/host.hpp"
 #include "ck_tile/ref/naive_attention.hpp"
 #include "mha_fwd.h"
@@ -1144,7 +1144,7 @@ bool run(const ck_tile::ArgParser& arg_parser)

     auto oacc_element_func = [&]() {
         if constexpr(std::is_same_v<DataTypeConfig, ck_tile::fp8_t>)
-            return ck_tile::composes(ck_tile::saturates<ck_tile::fp8_t>{},
+            return ck_tile::make_composes(ck_tile::saturates<ck_tile::fp8_t>{},
                                      ck_tile::scales{scale_o});
         else
             return ck_tile::identity{};
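
The hunk above swaps a direct `composes(...)` construction for a `make_composes(...)` factory call. The diff does not state the reason, and the sketch below does not reproduce ck_tile's implementation. It only illustrates the general pattern of a factory function that builds a composed functor from two unary functors. Every name in the `sketch` namespace is a hypothetical stand-in, and the application order (scale first, then saturate) is an assumption.

```cpp
#include <algorithm>
#include <type_traits>
#include <utility>

namespace sketch {

// Hypothetical stand-in for a scaling functor: multiplies by a fixed factor.
struct scales
{
    float factor;
    float operator()(float x) const { return x * factor; }
};

// Hypothetical stand-in for a saturating conversion: clamps to [-limit, limit].
struct saturates
{
    float limit;
    float operator()(float x) const { return std::clamp(x, -limit, limit); }
};

// Holds two unary functors and applies the inner one first (assumed order).
template <typename Outer, typename Inner>
struct composes
{
    Outer outer;
    Inner inner;

    template <typename T>
    auto operator()(T&& x) const
    {
        return outer(inner(std::forward<T>(x)));
    }
};

// Factory function: deduces the functor types from its arguments, so the
// call site does not rely on class template argument deduction.
template <typename Outer, typename Inner>
composes<std::decay_t<Outer>, std::decay_t<Inner>>
make_composes(Outer&& outer, Inner&& inner)
{
    return {std::forward<Outer>(outer), std::forward<Inner>(inner)};
}

} // namespace sketch
```

With these stand-ins, `sketch::make_composes(sketch::saturates{100.0f}, sketch::scales{0.5f})` yields a callable that halves its input and then clamps the result to the range [-100, 100].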