Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 57 additions & 27 deletions source/opt/folding_rules.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3312,45 +3312,75 @@ FoldingRule RedundantAndShift() {
const analysis::Type* type =
context->get_type_mgr()->GetType(inst->type_id());
uint32_t width = ElementWidth(type);
if ((width != 32) && (width != 64)) return false;
if (width != 8 && width != 16 && width != 32 && width != 64) return false;
const uint64_t width_mask =
(width == 64) ? ~0ull : ((1ull << width) - 1ull);

analysis::ConstantManager* const_mgr = context->get_constant_mgr();
const analysis::Constant* const_input1 = ConstInput(constants);
if (!const_input1) return false;
Instruction* other_inst = NonConstInput(context, constants[0], inst);

spv::Op other_op = other_inst->opcode();
if (other_op == spv::Op::OpShiftLeftLogical ||
other_op == spv::Op::OpShiftRightLogical) {
std::vector<const analysis::Constant*> other_constants =
const_mgr->GetOperandConstants(other_inst);
if (other_op != spv::Op::OpShiftLeftLogical &&
other_op != spv::Op::OpShiftRightLogical) {
return false;
}

// Only valid if const is on the right
if (other_constants[0]) {
return false;
}
const analysis::Constant* const_input2 = other_constants[1];
if (!const_input2) return false;
std::vector<const analysis::Constant*> other_constants =
const_mgr->GetOperandConstants(other_inst);

bool can_convert_to_zero = true;
ForEachIntegerConstantPair(
const_mgr, const_input1, const_input2,
[&can_convert_to_zero, other_op](auto lhs, auto rhs) {
if (other_op == spv::Op::OpShiftRightLogical) {
can_convert_to_zero = can_convert_to_zero && (lhs << rhs) == 0;
} else {
can_convert_to_zero = can_convert_to_zero && (lhs >> rhs) == 0;
}
});
// Only valid if const is on the right.
if (other_constants[0]) return false;
const analysis::Constant* const_input2 = other_constants[1];
if (!const_input2) return false;

if (can_convert_to_zero) {
auto zero_id = context->get_constant_mgr()->GetNullConstId(type);
inst->SetOpcode(spv::Op::OpCopyObject);
inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {zero_id}}});
return true;
auto get_value_u64 =
[](const analysis::Constant* c) -> std::optional<uint64_t> {
if (!c) return std::nullopt;
const analysis::Integer* int_t = c->type()->AsInteger();
if (!int_t) return std::nullopt;
return c->GetZeroExtendedValue();
};

auto can_fold_component =
[&](const analysis::Constant* mask_const,
const analysis::Constant* shift_const) -> std::optional<bool> {
auto lhs = get_value_u64(mask_const);
auto rhs = get_value_u64(shift_const);
if (!lhs || !rhs) return std::nullopt;
if (*rhs >= width) return false;
uint64_t lhs_masked = *lhs & width_mask;
if (other_op == spv::Op::OpShiftRightLogical) {
return ((lhs_masked << *rhs) & width_mask) == 0;
}
return ((lhs_masked >> *rhs) & width_mask) == 0;
};

if (const analysis::Vector* mask_vec = type->AsVector()) {
const analysis::Vector* shift_vec = const_input2->type()->AsVector();
if (!shift_vec ||
shift_vec->element_count() != mask_vec->element_count()) {
return false;
}
const auto mask_components = const_input1->GetVectorComponents(const_mgr);
const auto shift_components =
const_input2->GetVectorComponents(const_mgr);
for (uint32_t i = 0; i != mask_vec->element_count(); ++i) {
auto result =
can_fold_component(mask_components[i], shift_components[i]);
if (!result || !*result) return false;
}
} else {
if (const_input2->type()->AsVector()) return false;
auto result = can_fold_component(const_input1, const_input2);
if (!result || !*result) return false;
}
return false;

auto zero_id = context->get_constant_mgr()->GetNullConstId(type);
inst->SetOpcode(spv::Op::OpCopyObject);
inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {zero_id}}});
return true;
};
}

Expand Down
103 changes: 100 additions & 3 deletions test/opt/fold_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,10 @@ TEST_P(IntegerInstructionFoldingTest, Case) {
#define UINT_0_ID 109
#define INT_NULL_ID 110
#define UINT_NULL_ID 111
#define ULONG_NULL_ID 120
#define UBYTE_NULL_ID 121
#define USHORT_NULL_ID 122
#define V2USHORT_NULL_ID 123
#define HALF_3_ID 112
#define FLOAT_NULL_ID 113
const std::string& Header() {
Expand Down Expand Up @@ -280,6 +284,8 @@ OpName %main "main"
%_ptr_half = OpTypePointer Function %half
%_ptr_long = OpTypePointer Function %long
%_ptr_ulong = OpTypePointer Function %ulong
%_ptr_ubyte = OpTypePointer Function %ubyte
%_ptr_ushort = OpTypePointer Function %ushort
%_ptr_v2int = OpTypePointer Function %v2int
%_ptr_v4int = OpTypePointer Function %v4int
%_ptr_v4float = OpTypePointer Function %v4float
Expand All @@ -288,6 +294,7 @@ OpName %main "main"
%_ptr_struct_v2int_int_int = OpTypePointer Function %struct_v2int_int_int
%_ptr_v2float = OpTypePointer Function %v2float
%_ptr_v2double = OpTypePointer Function %v2double
%_ptr_v2ushort = OpTypePointer Function %v2ushort
%int_2 = OpConstant %int 2
%int_arr_2 = OpTypeArray %int %int_2
%short_n1 = OpConstant %short -1
Expand All @@ -306,7 +313,7 @@ OpName %main "main"
%ushort_1 = OpConstant %ushort 1
%ushort_2 = OpConstant %ushort 2
%ushort_3 = OpConstant %ushort 3
%ushort_null = OpConstantNull %ushort
%122 = OpConstantNull %ushort ; Need a def with an numerical id to define id maps.
%100 = OpConstant %int 0 ; Need a def with an numerical id to define id maps.
%110 = OpConstantNull %int ; Need a def with an numerical id to define id maps.
%103 = OpConstant %int 7 ; Need a def with an numerical id to define id maps.
Expand Down Expand Up @@ -341,6 +348,9 @@ OpName %main "main"
%ulong_4611686018427387904 = OpConstant %ulong 4611686018427387904
%109 = OpConstant %uint 0 ; Need a def with an numerical id to define id maps.
%111 = OpConstantNull %uint ; Need a def with an numerical id to define id maps.
%120 = OpConstantNull %ulong ; Need a def with an numerical id to define id maps.
%121 = OpConstantNull %ubyte ; Need a def with an numerical id to define id maps.
%123 = OpConstantNull %v2ushort ; Need a def with an numerical id to define id maps.
%uint_0 = OpConstant %uint 0
%uint_1 = OpConstant %uint 1
%uint_2 = OpConstant %uint 2
Expand Down Expand Up @@ -463,7 +473,7 @@ OpName %main "main"
%v4uint_1_0x0000ffff_uint_0_uint_max = OpConstantComposite %v4uint %uint_1 %uint_0x0000ffff %uint_0 %uint_max
%v2uint_1_null = OpConstantComposite %v2uint %uint_1 %111
%v2uint_null = OpConstantNull %v2uint
%v2ushort_1_null = OpConstantComposite %v2ushort %ushort_1 %ushort_null
%v2ushort_1_null = OpConstantComposite %v2ushort %ushort_1 %122
%v4ushort_0_1_2_3 = OpConstantComposite %v4ushort %ushort_0 %ushort_1 %ushort_2 %ushort_3
%v2ubyte_a_b = OpConstantComposite %v2ubyte %ubyte_a %ubyte_b
%v4ubyte_a_b_c_d = OpConstantComposite %v4ubyte %ubyte_a %ubyte_b %ubyte_c %ubyte_d
Expand Down Expand Up @@ -9419,7 +9429,94 @@ INSTANTIATE_TEST_SUITE_P(RedundantAndAddSubTest, MatchingInstructionFoldingTest,
"%2 = OpBitwiseAnd %uint %uint_4026531841 %3\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0)
2, 0),

// Test case 15: Fold (mixed-width shift amount)
// 1u64 & (n << 1u32) = 0
InstructionFoldingCase<uint32_t>(
Header() +
"%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_ulong Function\n" +
"%4 = OpLoad %ulong %n\n" +
"%3 = OpShiftLeftLogical %ulong %4 %uint_1\n" +
"%2 = OpBitwiseAnd %ulong %ulong_1 %3\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, ULONG_NULL_ID),

// Test case 16: Fold (8-bit base, 32-bit shift amount)
// 1u8 & (n << 1u32) = 0
InstructionFoldingCase<uint32_t>(
Header() +
"%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_ubyte Function\n" +
"%4 = OpLoad %ubyte %n\n" +
"%3 = OpShiftLeftLogical %ubyte %4 %uint_1\n" +
"%2 = OpBitwiseAnd %ubyte %ubyte_1 %3\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, UBYTE_NULL_ID),

// Test case 17: Fold (16-bit base, 32-bit shift amount)
// 1u16 & (n << 1u32) = 0
InstructionFoldingCase<uint32_t>(
Header() +
"%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_ushort Function\n" +
"%4 = OpLoad %ushort %n\n" +
"%3 = OpShiftLeftLogical %ushort %4 %uint_1\n" +
"%2 = OpBitwiseAnd %ushort %ushort_1 %3\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, USHORT_NULL_ID),

// Test case 18: Fold (vector, mixed-width shift amount)
// <1,0>u16 & (n << <1,0>u32) = 0
InstructionFoldingCase<uint32_t>(
Header() +
"%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_v2ushort Function\n" +
"%4 = OpLoad %v2ushort %n\n" +
"%3 = OpShiftLeftLogical %v2ushort %4 %v2uint_1_null\n" +
"%2 = OpBitwiseAnd %v2ushort %v2ushort_1_null %3\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, V2USHORT_NULL_ID),

// Test case 19: Fold (8-bit base, 32-bit shift amount, right shift)
// 0x80u8 & (n >> 1u32) = 0
InstructionFoldingCase<uint32_t>(
Header() +
"%ubyte_128 = OpConstant %ubyte 128\n" +
"%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_ubyte Function\n" +
"%4 = OpLoad %ubyte %n\n" +
"%3 = OpShiftRightLogical %ubyte %4 %uint_1\n" +
"%2 = OpBitwiseAnd %ubyte %ubyte_128 %3\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, UBYTE_NULL_ID),

// Test case 20: Fold (vector, mixed-width shift amount, right shift)
// <0x8000,0>u16 & (n >> <1,0>u32) = 0
InstructionFoldingCase<uint32_t>(
Header() +
"%ushort_32768 = OpConstant %ushort 32768\n" +
"%v2ushort_32768_0 = OpConstantComposite %v2ushort %ushort_32768 %ushort_0\n" +
"%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%n = OpVariable %_ptr_v2ushort Function\n" +
"%4 = OpLoad %v2ushort %n\n" +
"%3 = OpShiftRightLogical %v2ushort %4 %v2uint_1_null\n" +
"%2 = OpBitwiseAnd %v2ushort %v2ushort_32768_0 %3\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, V2USHORT_NULL_ID)
));

INSTANTIATE_TEST_SUITE_P(MergeAddTest, MatchingInstructionFoldingTest,
Expand Down