From e24af6fa154d8e25fb448633595c65c32ba9edb2 Mon Sep 17 00:00:00 2001 From: spencer <94135891+spencer005@users.noreply.github.com> Date: Wed, 21 Jan 2026 02:57:16 -0500 Subject: [PATCH] spirv-opt: handle mixed-width shifts in RedundantAndShift I encountered a crash with instruction folding enabled while using shaderc (test case 15). This commit adds instruction-folding support for all the different cases I interpreted as valid while reading the SPIR-V spec: https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html 3.3.14. Bit Instructions --- source/opt/folding_rules.cpp | 84 +++++++++++++++++++--------- test/opt/fold_test.cpp | 103 ++++++++++++++++++++++++++++++++++- 2 files changed, 157 insertions(+), 30 deletions(-) diff --git a/source/opt/folding_rules.cpp b/source/opt/folding_rules.cpp index a20d9040d4..bfc6d8804a 100644 --- a/source/opt/folding_rules.cpp +++ b/source/opt/folding_rules.cpp @@ -3312,7 +3312,9 @@ FoldingRule RedundantAndShift() { const analysis::Type* type = context->get_type_mgr()->GetType(inst->type_id()); uint32_t width = ElementWidth(type); - if ((width != 32) && (width != 64)) return false; + if (width != 8 && width != 16 && width != 32 && width != 64) return false; + const uint64_t width_mask = + (width == 64) ?
~0ull : ((1ull << width) - 1ull); analysis::ConstantManager* const_mgr = context->get_constant_mgr(); const analysis::Constant* const_input1 = ConstInput(constants); @@ -3320,37 +3322,65 @@ FoldingRule RedundantAndShift() { Instruction* other_inst = NonConstInput(context, constants[0], inst); spv::Op other_op = other_inst->opcode(); - if (other_op == spv::Op::OpShiftLeftLogical || - other_op == spv::Op::OpShiftRightLogical) { - std::vector other_constants = - const_mgr->GetOperandConstants(other_inst); + if (other_op != spv::Op::OpShiftLeftLogical && + other_op != spv::Op::OpShiftRightLogical) { + return false; + } - // Only valid if const is on the right - if (other_constants[0]) { - return false; - } - const analysis::Constant* const_input2 = other_constants[1]; - if (!const_input2) return false; + std::vector other_constants = + const_mgr->GetOperandConstants(other_inst); - bool can_convert_to_zero = true; - ForEachIntegerConstantPair( - const_mgr, const_input1, const_input2, - [&can_convert_to_zero, other_op](auto lhs, auto rhs) { - if (other_op == spv::Op::OpShiftRightLogical) { - can_convert_to_zero = can_convert_to_zero && (lhs << rhs) == 0; - } else { - can_convert_to_zero = can_convert_to_zero && (lhs >> rhs) == 0; - } - }); + // Only valid if const is on the right. 
+ if (other_constants[0]) return false; + const analysis::Constant* const_input2 = other_constants[1]; + if (!const_input2) return false; - if (can_convert_to_zero) { - auto zero_id = context->get_constant_mgr()->GetNullConstId(type); - inst->SetOpcode(spv::Op::OpCopyObject); - inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {zero_id}}}); - return true; + auto get_value_u64 = + [](const analysis::Constant* c) -> std::optional { + if (!c) return std::nullopt; + const analysis::Integer* int_t = c->type()->AsInteger(); + if (!int_t) return std::nullopt; + return c->GetZeroExtendedValue(); + }; + + auto can_fold_component = + [&](const analysis::Constant* mask_const, + const analysis::Constant* shift_const) -> std::optional { + auto lhs = get_value_u64(mask_const); + auto rhs = get_value_u64(shift_const); + if (!lhs || !rhs) return std::nullopt; + if (*rhs >= width) return false; + uint64_t lhs_masked = *lhs & width_mask; + if (other_op == spv::Op::OpShiftRightLogical) { + return ((lhs_masked << *rhs) & width_mask) == 0; } + return ((lhs_masked >> *rhs) & width_mask) == 0; + }; + + if (const analysis::Vector* mask_vec = type->AsVector()) { + const analysis::Vector* shift_vec = const_input2->type()->AsVector(); + if (!shift_vec || + shift_vec->element_count() != mask_vec->element_count()) { + return false; + } + const auto mask_components = const_input1->GetVectorComponents(const_mgr); + const auto shift_components = + const_input2->GetVectorComponents(const_mgr); + for (uint32_t i = 0; i != mask_vec->element_count(); ++i) { + auto result = + can_fold_component(mask_components[i], shift_components[i]); + if (!result || !*result) return false; + } + } else { + if (const_input2->type()->AsVector()) return false; + auto result = can_fold_component(const_input1, const_input2); + if (!result || !*result) return false; } - return false; + + auto zero_id = context->get_constant_mgr()->GetNullConstId(type); + inst->SetOpcode(spv::Op::OpCopyObject); + 
inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {zero_id}}}); + return true; }; } diff --git a/test/opt/fold_test.cpp b/test/opt/fold_test.cpp index 8c7649bb90..1b6d8403c5 100644 --- a/test/opt/fold_test.cpp +++ b/test/opt/fold_test.cpp @@ -211,6 +211,10 @@ TEST_P(IntegerInstructionFoldingTest, Case) { #define UINT_0_ID 109 #define INT_NULL_ID 110 #define UINT_NULL_ID 111 +#define ULONG_NULL_ID 120 +#define UBYTE_NULL_ID 121 +#define USHORT_NULL_ID 122 +#define V2USHORT_NULL_ID 123 #define HALF_3_ID 112 #define FLOAT_NULL_ID 113 const std::string& Header() { @@ -280,6 +284,8 @@ OpName %main "main" %_ptr_half = OpTypePointer Function %half %_ptr_long = OpTypePointer Function %long %_ptr_ulong = OpTypePointer Function %ulong +%_ptr_ubyte = OpTypePointer Function %ubyte +%_ptr_ushort = OpTypePointer Function %ushort %_ptr_v2int = OpTypePointer Function %v2int %_ptr_v4int = OpTypePointer Function %v4int %_ptr_v4float = OpTypePointer Function %v4float @@ -288,6 +294,7 @@ OpName %main "main" %_ptr_struct_v2int_int_int = OpTypePointer Function %struct_v2int_int_int %_ptr_v2float = OpTypePointer Function %v2float %_ptr_v2double = OpTypePointer Function %v2double +%_ptr_v2ushort = OpTypePointer Function %v2ushort %int_2 = OpConstant %int 2 %int_arr_2 = OpTypeArray %int %int_2 %short_n1 = OpConstant %short -1 @@ -306,7 +313,7 @@ OpName %main "main" %ushort_1 = OpConstant %ushort 1 %ushort_2 = OpConstant %ushort 2 %ushort_3 = OpConstant %ushort 3 -%ushort_null = OpConstantNull %ushort +%122 = OpConstantNull %ushort ; Need a def with an numerical id to define id maps. %100 = OpConstant %int 0 ; Need a def with an numerical id to define id maps. %110 = OpConstantNull %int ; Need a def with an numerical id to define id maps. %103 = OpConstant %int 7 ; Need a def with an numerical id to define id maps. @@ -341,6 +348,9 @@ OpName %main "main" %ulong_4611686018427387904 = OpConstant %ulong 4611686018427387904 %109 = OpConstant %uint 0 ; Need a def with an numerical id to define id maps. 
%111 = OpConstantNull %uint ; Need a def with an numerical id to define id maps. +%120 = OpConstantNull %ulong ; Need a def with an numerical id to define id maps. +%121 = OpConstantNull %ubyte ; Need a def with an numerical id to define id maps. +%123 = OpConstantNull %v2ushort ; Need a def with an numerical id to define id maps. %uint_0 = OpConstant %uint 0 %uint_1 = OpConstant %uint 1 %uint_2 = OpConstant %uint 2 @@ -463,7 +473,7 @@ OpName %main "main" %v4uint_1_0x0000ffff_uint_0_uint_max = OpConstantComposite %v4uint %uint_1 %uint_0x0000ffff %uint_0 %uint_max %v2uint_1_null = OpConstantComposite %v2uint %uint_1 %111 %v2uint_null = OpConstantNull %v2uint -%v2ushort_1_null = OpConstantComposite %v2ushort %ushort_1 %ushort_null +%v2ushort_1_null = OpConstantComposite %v2ushort %ushort_1 %122 %v4ushort_0_1_2_3 = OpConstantComposite %v4ushort %ushort_0 %ushort_1 %ushort_2 %ushort_3 %v2ubyte_a_b = OpConstantComposite %v2ubyte %ubyte_a %ubyte_b %v4ubyte_a_b_c_d = OpConstantComposite %v4ubyte %ubyte_a %ubyte_b %ubyte_c %ubyte_d @@ -9419,7 +9429,94 @@ INSTANTIATE_TEST_SUITE_P(RedundantAndAddSubTest, MatchingInstructionFoldingTest, "%2 = OpBitwiseAnd %uint %uint_4026531841 %3\n" + "OpReturn\n" + "OpFunctionEnd", - 2, 0) + 2, 0), + + // Test case 15: Fold (mixed-width shift amount) + // 1u64 & (n << 1u32) = 0 + InstructionFoldingCase( + Header() + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_ulong Function\n" + + "%4 = OpLoad %ulong %n\n" + + "%3 = OpShiftLeftLogical %ulong %4 %uint_1\n" + + "%2 = OpBitwiseAnd %ulong %ulong_1 %3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, ULONG_NULL_ID), + + // Test case 16: Fold (8-bit base, 32-bit shift amount) + // 1u8 & (n << 1u32) = 0 + InstructionFoldingCase( + Header() + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_ubyte Function\n" + + "%4 = OpLoad %ubyte %n\n" + + "%3 = OpShiftLeftLogical %ubyte %4 %uint_1\n" + + 
"%2 = OpBitwiseAnd %ubyte %ubyte_1 %3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, UBYTE_NULL_ID), + + // Test case 17: Fold (16-bit base, 32-bit shift amount) + // 1u16 & (n << 1u32) = 0 + InstructionFoldingCase( + Header() + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_ushort Function\n" + + "%4 = OpLoad %ushort %n\n" + + "%3 = OpShiftLeftLogical %ushort %4 %uint_1\n" + + "%2 = OpBitwiseAnd %ushort %ushort_1 %3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, USHORT_NULL_ID), + + // Test case 18: Fold (vector, mixed-width shift amount) + // <1,0>u16 & (n << <1,0>u32) = 0 + InstructionFoldingCase( + Header() + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_v2ushort Function\n" + + "%4 = OpLoad %v2ushort %n\n" + + "%3 = OpShiftLeftLogical %v2ushort %4 %v2uint_1_null\n" + + "%2 = OpBitwiseAnd %v2ushort %v2ushort_1_null %3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, V2USHORT_NULL_ID), + + // Test case 19: Fold (8-bit base, 32-bit shift amount, right shift) + // 0x80u8 & (n >> 1u32) = 0 + InstructionFoldingCase( + Header() + + "%ubyte_128 = OpConstant %ubyte 128\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_ubyte Function\n" + + "%4 = OpLoad %ubyte %n\n" + + "%3 = OpShiftRightLogical %ubyte %4 %uint_1\n" + + "%2 = OpBitwiseAnd %ubyte %ubyte_128 %3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, UBYTE_NULL_ID), + + // Test case 20: Fold (vector, mixed-width shift amount, right shift) + // <0x8000,0>u16 & (n >> <1,0>u32) = 0 + InstructionFoldingCase( + Header() + + "%ushort_32768 = OpConstant %ushort 32768\n" + + "%v2ushort_32768_0 = OpConstantComposite %v2ushort %ushort_32768 %ushort_0\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%n = OpVariable %_ptr_v2ushort Function\n" + + "%4 = OpLoad %v2ushort %n\n" + + "%3 = OpShiftRightLogical %v2ushort %4 
%v2uint_1_null\n" + + "%2 = OpBitwiseAnd %v2ushort %v2ushort_32768_0 %3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 2, V2USHORT_NULL_ID) )); INSTANTIATE_TEST_SUITE_P(MergeAddTest, MatchingInstructionFoldingTest,