Skip to content

Commit c056c09

Browse files
authored
[CPU] Refactor match_conv_mul_add_fq util function (#34117)
### Details: Remove NOLINT and template instantiations ### Tickets: - N/A
1 parent 2ea24b1 commit c056c09

File tree

2 files changed

+27
-33
lines changed

2 files changed

+27
-33
lines changed

src/plugins/intel_cpu/src/transformations/utils.cpp

Lines changed: 0 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212

1313
#include "openvino/core/node.hpp"
1414
#include "openvino/core/shape.hpp"
15-
#include "openvino/core/type/element_type.hpp"
1615
#include "openvino/op/add.hpp"
1716
#include "openvino/op/convolution.hpp"
1817
#include "openvino/op/fake_quantize.hpp"
@@ -22,41 +21,10 @@
2221
#include "openvino/pass/pattern/op/pattern.hpp"
2322
#include "openvino/pass/pattern/op/wrap_type.hpp"
2423

25-
#if defined(OPENVINO_ARCH_ARM) || defined(OPENVINO_ARCH_ARM64)
26-
# include "openvino/op/subtract.hpp" // NOLINT(misc-include-cleaner) needed for explicit template instantiation
27-
#endif
28-
2924
using namespace ov::pass::pattern;
3025

3126
namespace ov::intel_cpu {
3227

33-
template <class T>
34-
bool match_conv_mul_add_fq(const std::shared_ptr<const ov::Node>& node) {
35-
auto conv_m = wrap_type<ov::op::v1::Convolution>(
36-
{any_input(type_matches_any({ov::element::i8, ov::element::u8})), any_input()});
37-
auto mul0_m = wrap_type<ov::op::v1::Multiply>({conv_m, any_input()});
38-
auto add_m = wrap_type<ov::op::v1::Add>({mul0_m, any_input()});
39-
auto fq_m = wrap_type<ov::op::v0::FakeQuantize>({add_m, any_input(), any_input(), any_input(), any_input()},
40-
type_matches_any({ov::element::i8, ov::element::u8}));
41-
auto final_m = wrap_type<T>({fq_m, any_input()});
42-
43-
auto matcher = std::make_shared<Matcher>(final_m);
44-
if (!matcher->match(std::const_pointer_cast<ov::Node>(node))) {
45-
return false;
46-
}
47-
48-
const auto& pattern_map = matcher->get_pattern_value_map();
49-
const auto fq = pattern_map.at(fq_m).get_node_shared_ptr();
50-
const auto conv = pattern_map.at(conv_m).get_node_shared_ptr();
51-
52-
return conv->get_input_element_type(0) == fq->get_output_element_type(0);
53-
}
54-
55-
#if defined(OPENVINO_ARCH_ARM) || defined(OPENVINO_ARCH_ARM64)
56-
template bool match_conv_mul_add_fq<ov::op::v1::Subtract>(const std::shared_ptr<const ov::Node>& node);
57-
template bool match_conv_mul_add_fq<ov::op::v1::Multiply>(const std::shared_ptr<const ov::Node>& node);
58-
#endif
59-
6028
bool match_fq_mul_conv_bias_same_types(const std::shared_ptr<const ov::Node>& node, FQMulAddPattern pattern) {
6129
auto convMulAdd_conv = wrap_type<ov::op::v1::Convolution>();
6230
auto convMulAdd_mul = wrap_type<ov::op::v1::Multiply>({convMulAdd_conv, any_input()});

src/plugins/intel_cpu/src/transformations/utils.hpp

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#pragma once
66

77
#include <cstdint>
8+
#include <type_traits>
89

910
#include "openvino/core/model.hpp"
1011
#include "openvino/core/shape.hpp"
@@ -13,6 +14,7 @@
1314
#include "openvino/op/convolution.hpp"
1415
#include "openvino/op/fake_quantize.hpp"
1516
#include "openvino/op/multiply.hpp"
17+
#include "openvino/op/subtract.hpp"
1618
#include "openvino/pass/pattern/matcher.hpp"
1719
#include "openvino/pass/pattern/op/label.hpp"
1820
#include "openvino/pass/pattern/op/pattern.hpp"
@@ -21,7 +23,31 @@
2123
namespace ov::intel_cpu {
2224

2325
template <class T>
24-
bool match_conv_mul_add_fq(const std::shared_ptr<const ov::Node>& node);
26+
bool match_conv_mul_add_fq(const std::shared_ptr<const ov::Node>& node) {
27+
static_assert(std::is_same_v<T, ov::op::v1::Subtract> || std::is_same_v<T, ov::op::v1::Multiply>,
28+
"match_conv_mul_add_fq supports only Subtract and Multiply");
29+
30+
using namespace ov::pass::pattern;
31+
32+
auto conv_m = wrap_type<ov::op::v1::Convolution>(
33+
{any_input(type_matches_any({ov::element::i8, ov::element::u8})), any_input()});
34+
auto mul0_m = wrap_type<ov::op::v1::Multiply>({conv_m, any_input()});
35+
auto add_m = wrap_type<ov::op::v1::Add>({mul0_m, any_input()});
36+
auto fq_m = wrap_type<ov::op::v0::FakeQuantize>({add_m, any_input(), any_input(), any_input(), any_input()},
37+
type_matches_any({ov::element::i8, ov::element::u8}));
38+
auto final_m = wrap_type<T>({fq_m, any_input()});
39+
40+
auto matcher = std::make_shared<ov::pass::pattern::Matcher>(final_m);
41+
if (!matcher->match(std::const_pointer_cast<ov::Node>(node))) {
42+
return false;
43+
}
44+
45+
const auto& pattern_map = matcher->get_pattern_value_map();
46+
const auto fq = pattern_map.at(fq_m).get_node_shared_ptr();
47+
const auto conv = pattern_map.at(conv_m).get_node_shared_ptr();
48+
49+
return conv->get_input_element_type(0) == fq->get_output_element_type(0);
50+
}
2551

2652
enum class FQMulAddPattern : std::uint8_t { ConvMulAdd, ConvAddMul };
2753

0 commit comments

Comments
 (0)