|
12 | 12 |
|
13 | 13 | #include "openvino/core/node.hpp" |
14 | 14 | #include "openvino/core/shape.hpp" |
15 | | -#include "openvino/core/type/element_type.hpp" |
16 | 15 | #include "openvino/op/add.hpp" |
17 | 16 | #include "openvino/op/convolution.hpp" |
18 | 17 | #include "openvino/op/fake_quantize.hpp" |
|
22 | 21 | #include "openvino/pass/pattern/op/pattern.hpp" |
23 | 22 | #include "openvino/pass/pattern/op/wrap_type.hpp" |
24 | 23 |
|
25 | | -#if defined(OPENVINO_ARCH_ARM) || defined(OPENVINO_ARCH_ARM64) |
26 | | -# include "openvino/op/subtract.hpp" // NOLINT(misc-include-cleaner) needed for explicit template instantiation |
27 | | -#endif |
28 | | - |
29 | 24 | using namespace ov::pass::pattern; |
30 | 25 |
|
31 | 26 | namespace ov::intel_cpu { |
32 | 27 |
|
33 | | -template <class T> |
34 | | -bool match_conv_mul_add_fq(const std::shared_ptr<const ov::Node>& node) { |
35 | | - auto conv_m = wrap_type<ov::op::v1::Convolution>( |
36 | | - {any_input(type_matches_any({ov::element::i8, ov::element::u8})), any_input()}); |
37 | | - auto mul0_m = wrap_type<ov::op::v1::Multiply>({conv_m, any_input()}); |
38 | | - auto add_m = wrap_type<ov::op::v1::Add>({mul0_m, any_input()}); |
39 | | - auto fq_m = wrap_type<ov::op::v0::FakeQuantize>({add_m, any_input(), any_input(), any_input(), any_input()}, |
40 | | - type_matches_any({ov::element::i8, ov::element::u8})); |
41 | | - auto final_m = wrap_type<T>({fq_m, any_input()}); |
42 | | - |
43 | | - auto matcher = std::make_shared<Matcher>(final_m); |
44 | | - if (!matcher->match(std::const_pointer_cast<ov::Node>(node))) { |
45 | | - return false; |
46 | | - } |
47 | | - |
48 | | - const auto& pattern_map = matcher->get_pattern_value_map(); |
49 | | - const auto fq = pattern_map.at(fq_m).get_node_shared_ptr(); |
50 | | - const auto conv = pattern_map.at(conv_m).get_node_shared_ptr(); |
51 | | - |
52 | | - return conv->get_input_element_type(0) == fq->get_output_element_type(0); |
53 | | -} |
54 | | - |
55 | | -#if defined(OPENVINO_ARCH_ARM) || defined(OPENVINO_ARCH_ARM64) |
56 | | -template bool match_conv_mul_add_fq<ov::op::v1::Subtract>(const std::shared_ptr<const ov::Node>& node); |
57 | | -template bool match_conv_mul_add_fq<ov::op::v1::Multiply>(const std::shared_ptr<const ov::Node>& node); |
58 | | -#endif |
59 | | - |
60 | 28 | bool match_fq_mul_conv_bias_same_types(const std::shared_ptr<const ov::Node>& node, FQMulAddPattern pattern) { |
61 | 29 | auto convMulAdd_conv = wrap_type<ov::op::v1::Convolution>(); |
62 | 30 | auto convMulAdd_mul = wrap_type<ov::op::v1::Multiply>({convMulAdd_conv, any_input()}); |
|
0 commit comments