-
Notifications
You must be signed in to change notification settings - Fork 16.1k
[mlir][spirv] Add 6 Element Binary operators to TOSA Ext Inst Set #179627
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||
|---|---|---|---|---|---|---|---|---|
|
|
@@ -62,6 +62,19 @@ class SPIRV_TosaOpWithComplexResult<string mnemonic, int opcode, list<Trait> tra | |||||||
| }]; | ||||||||
| } | ||||||||
|
|
||||||||
| class SPIRV_TosaElementwiseBinaryOp<string mnemonic, int opcode, list<Trait> traits = []> : | ||||||||
| SPIRV_TosaOpWithResult<mnemonic, opcode, traits> { | ||||||||
|
|
||||||||
| let extraClassDeclaration = extraBaseClassDeclaration#[{ | ||||||||
| ::mlir::spirv::TensorArmType getInput1Type() { | ||||||||
| return cast<::mlir::spirv::TensorArmType>(getInput1().getType()); | ||||||||
| } | ||||||||
| ::mlir::spirv::TensorArmType getInput2Type() { | ||||||||
| return cast<::mlir::spirv::TensorArmType>(getInput2().getType()); | ||||||||
| } | ||||||||
| }]; | ||||||||
| } | ||||||||
|
|
||||||||
|
|
||||||||
| def SPIRV_TosaArgMaxOp : SPIRV_TosaOpWithResult<"ArgMax", 0, [Pure, | ||||||||
| OutputRankIsInputRankMinusOne<"input", "output">, | ||||||||
|
|
@@ -863,4 +876,229 @@ def SPIRV_TosaTanhOp : SPIRV_TosaOpWithResult<"Tanh", 13, [Pure, | |||||||
| } | ||||||||
|
|
||||||||
|
|
||||||||
| def SPIRV_TosaAddOp : SPIRV_TosaElementwiseBinaryOp<"Add", 14, [Pure, | ||||||||
| AllElementTypesMatch<["input1", "input2", "output"]>, | ||||||||
| AllRanksMatch<["input1", "input2"]>]> { | ||||||||
|
Comment on lines
+880
to
+881
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should we also move these traits to the base class? |
||||||||
| let summary = "Addition operator."; | ||||||||
|
|
||||||||
| let description = [{ | ||||||||
| Elementwise Addition of input1 and input2. Axis of size 1 will be broadcast, | ||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we need a verifier that checks that the types must match in shape modulo size 1 dims?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No, should we have one? We should also take into account to ignore dynamic shape when comparing.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it would be useful to verify this |
||||||||
| as necessary. Rank of input tensors must match. | ||||||||
|
|
||||||||
| References: | ||||||||
| * https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_add | ||||||||
| * https://www.mlplatform.org/tosa/tosa_spec_1_0_1.html#_add | ||||||||
|
|
||||||||
| #### Example: | ||||||||
| ```mlir | ||||||||
| %0 = spirv.Tosa.Add %arg0, %arg1 : !spirv.arm.tensor<4x7x3x10xi32>, !spirv.arm.tensor<4x7x3x1xi32> -> !spirv.arm.tensor<4x7x3x10xi32> | ||||||||
| %0 = spirv.Tosa.Add %arg0, %arg1 : !spirv.arm.tensor<26x37x18xf16>, !spirv.arm.tensor<1x37x18xf16> -> !spirv.arm.tensor<26x37x18xf16> | ||||||||
| ``` | ||||||||
| }]; | ||||||||
|
|
||||||||
| let arguments = (ins | ||||||||
| SPIRV_TosaNumerical_TensorArm: $input1, | ||||||||
| SPIRV_TosaNumerical_TensorArm: $input2 | ||||||||
| ); | ||||||||
|
|
||||||||
| let results = (outs | ||||||||
| SPIRV_TosaNumerical_TensorArm: $output | ||||||||
| ); | ||||||||
|
|
||||||||
| let assemblyFormat = [{ | ||||||||
| $input1 `,` | ||||||||
| $input2 | ||||||||
|
Comment on lines
+909
to
+910
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
| attr-dict `:` type(operands) `->` type(results) | ||||||||
| }]; | ||||||||
| } | ||||||||
|
|
||||||||
|
|
||||||||
| def SPIRV_TosaArithmeticRightShiftOp : SPIRV_TosaElementwiseBinaryOp<"ArithmeticRightShift", 15, [Pure, | ||||||||
| AllElementTypesMatch<["input1", "input2", "output"]>, | ||||||||
| AllRanksMatch<["input1", "input2"]>]> { | ||||||||
| let summary = "Arithmetic Right Shift."; | ||||||||
|
|
||||||||
| let description = [{ | ||||||||
| Elementwise Arithmetic Right Shift of input1 by the amount specified in | ||||||||
| input2. Axis of size 1 will be broadcast, as necessary. Rank of input tensors | ||||||||
| must match. | ||||||||
|
Comment on lines
+922
to
+924
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you specify what happens when the shift amount is greater or equal to the bitwidth? Or by a negative amount.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That is undefined behaviour and it is implementation dependent.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should document this in the op description. If this is the case, the op can't have the |
||||||||
|
|
||||||||
| References: | ||||||||
| * https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_arithmetic_right_shift | ||||||||
| * https://www.mlplatform.org/tosa/tosa_spec_1_0_1.html#_arithmetic_right_shift | ||||||||
|
|
||||||||
| #### Example: | ||||||||
| ```mlir | ||||||||
| %1 = spirv.Tosa.ArithmeticRightShift round = true, %arg0, %arg1 : !spirv.arm.tensor<1x47x22xi16>, !spirv.arm.tensor<49x47x22xi16> -> !spirv.arm.tensor<49x47x22xi16> | ||||||||
| ``` | ||||||||
| }]; | ||||||||
|
|
||||||||
| let arguments = (ins | ||||||||
| SPIRV_BoolConstAttr: $round, | ||||||||
| SPIRV_TosaInteger_TensorArm: $input1, | ||||||||
| SPIRV_TosaInteger_TensorArm: $input2 | ||||||||
| ); | ||||||||
|
|
||||||||
| let results = (outs | ||||||||
| SPIRV_TosaInteger_TensorArm: $output | ||||||||
| ); | ||||||||
|
|
||||||||
| let assemblyFormat = [{ | ||||||||
| `round` `=` $round `,` | ||||||||
| $input1 `,` | ||||||||
| $input2 | ||||||||
|
Comment on lines
+948
to
+949
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
| attr-dict `:` type(operands) `->` type(results) | ||||||||
| }]; | ||||||||
| } | ||||||||
|
|
||||||||
|
|
||||||||
| def SPIRV_TosaBitwiseAndOp : SPIRV_TosaElementwiseBinaryOp<"BitwiseAnd", 16, [Pure, | ||||||||
| AllElementTypesMatch<["input1", "input2", "output"]>, | ||||||||
| AllRanksMatch<["input1", "input2"]>]> { | ||||||||
| let summary = "Bitwise AND operator."; | ||||||||
|
|
||||||||
| let description = [{ | ||||||||
| Elementwise Bitwise AND of input1 and input2. Axis of size 1 | ||||||||
| will be broadcast as necessary. Rank of input tensors must match. | ||||||||
|
|
||||||||
| References: | ||||||||
| * https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_bitwise_and | ||||||||
| * https://www.mlplatform.org/tosa/tosa_spec_1_0_1.html#_bitwise_and | ||||||||
|
|
||||||||
| #### Example: | ||||||||
| ```mlir | ||||||||
| %0 = spirv.Tosa.BitwiseAnd %arg0, %arg1 : !spirv.arm.tensor<4x1x7x12xi16>, !spirv.arm.tensor<4x13x7x12xi16> -> !spirv.arm.tensor<4x13x7x12xi16> | ||||||||
| ``` | ||||||||
| }]; | ||||||||
|
|
||||||||
| let arguments = (ins | ||||||||
| SPIRV_TosaInteger_TensorArm: $input1, | ||||||||
| SPIRV_TosaInteger_TensorArm: $input2 | ||||||||
| ); | ||||||||
|
|
||||||||
| let results = (outs | ||||||||
| SPIRV_TosaInteger_TensorArm: $output | ||||||||
| ); | ||||||||
|
|
||||||||
| let assemblyFormat = [{ | ||||||||
| $input1 `,` | ||||||||
| $input2 | ||||||||
|
Comment on lines
+984
to
+985
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
| attr-dict `:` type(operands) `->` type(results) | ||||||||
| }]; | ||||||||
| } | ||||||||
|
|
||||||||
|
|
||||||||
| def SPIRV_TosaBitwiseOrOp : SPIRV_TosaElementwiseBinaryOp<"BitwiseOr", 17, [Pure, | ||||||||
| AllElementTypesMatch<["input1", "input2", "output"]>, | ||||||||
| AllRanksMatch<["input1", "input2"]>]> { | ||||||||
| let summary = "Bitwise OR operator."; | ||||||||
|
|
||||||||
| let description = [{ | ||||||||
| Elementwise Bitwise OR of input1 and input2. Axis of size 1 will be | ||||||||
| broadcast as necessary. Rank of input tensors must match. | ||||||||
|
|
||||||||
| References: | ||||||||
| * https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_bitwise_or | ||||||||
| * https://www.mlplatform.org/tosa/tosa_spec_1_0_1.html#_bitwise_or | ||||||||
|
|
||||||||
| #### Example: | ||||||||
| ```mlir | ||||||||
| %0 = spirv.Tosa.BitwiseOr %arg0, %arg1 : !spirv.arm.tensor<11x30x23xi32>, !spirv.arm.tensor<1x30x23xi32> -> !spirv.arm.tensor<11x30x23xi32> | ||||||||
| ``` | ||||||||
| }]; | ||||||||
|
|
||||||||
| let arguments = (ins | ||||||||
| SPIRV_TosaInteger_TensorArm: $input1, | ||||||||
| SPIRV_TosaInteger_TensorArm: $input2 | ||||||||
| ); | ||||||||
|
|
||||||||
| let results = (outs | ||||||||
| SPIRV_TosaInteger_TensorArm: $output | ||||||||
| ); | ||||||||
|
|
||||||||
| let assemblyFormat = [{ | ||||||||
| $input1 `,` | ||||||||
| $input2 | ||||||||
|
Comment on lines
+1020
to
+1021
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
| attr-dict `:` type(operands) `->` type(results) | ||||||||
| }]; | ||||||||
| } | ||||||||
|
|
||||||||
|
|
||||||||
| def SPIRV_TosaBitwiseXorOp : SPIRV_TosaElementwiseBinaryOp<"BitwiseXor", 18, [Pure, | ||||||||
| AllElementTypesMatch<["input1", "input2", "output"]>, | ||||||||
| AllRanksMatch<["input1", "input2"]>]> { | ||||||||
| let summary = "Bitwise XOR operator."; | ||||||||
|
|
||||||||
| let description = [{ | ||||||||
| Elementwise Bitwise XOR of input1 and input2. Axis of size 1 will be | ||||||||
| broadcast as necessary. Rank of input tensors must match. | ||||||||
|
|
||||||||
| References: | ||||||||
| * https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_bitwise_xor | ||||||||
| * https://www.mlplatform.org/tosa/tosa_spec_1_0_1.html#_bitwise_xor | ||||||||
|
|
||||||||
| #### Example: | ||||||||
| ```mlir | ||||||||
| %0 = spirv.Tosa.BitwiseXor %arg0, %arg1 : !spirv.arm.tensor<4x8x13x9xi16>, !spirv.arm.tensor<4x8x1x9xi16> -> !spirv.arm.tensor<4x8x13x9xi16> | ||||||||
| ``` | ||||||||
| }]; | ||||||||
|
|
||||||||
| let arguments = (ins | ||||||||
| SPIRV_TosaInteger_TensorArm: $input1, | ||||||||
| SPIRV_TosaInteger_TensorArm: $input2 | ||||||||
| ); | ||||||||
|
|
||||||||
| let results = (outs | ||||||||
| SPIRV_TosaInteger_TensorArm: $output | ||||||||
| ); | ||||||||
|
|
||||||||
| let assemblyFormat = [{ | ||||||||
| $input1 `,` | ||||||||
| $input2 | ||||||||
|
Comment on lines
+1056
to
+1057
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
| attr-dict `:` type(operands) `->` type(results) | ||||||||
| }]; | ||||||||
| } | ||||||||
|
|
||||||||
|
|
||||||||
| def SPIRV_TosaIntDivOp : SPIRV_TosaElementwiseBinaryOp<"IntDiv", 19, [Pure, | ||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why would int division be pure? It has no memory effects but I would be surprised if it was speculatable. We should have a test for this.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes you are right. I was doing too much "happy copying" from TOSA dialect here. (IntDiv is marked Pure there.)
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. sounds like a bug on their side |
||||||||
| AllElementTypesMatch<["input1", "input2", "output"]>, | ||||||||
| AllRanksMatch<["input1", "input2"]>]> { | ||||||||
| let summary = "Integer Divide operator."; | ||||||||
|
|
||||||||
| let description = [{ | ||||||||
| Elementwise Integer Divide of input1 by input2. Axis of size 1 will be | ||||||||
| broadcast as necessary. Rank of input tensors must match. | ||||||||
|
|
||||||||
| The result of the divide is truncated towards zero. Expected use is for | ||||||||
| operations on non-scaled integers. Floating point divide should use | ||||||||
| `spirv.Tosa.Reciprocal` and `spirv.Tosa.Mul`. Quantized integer divide | ||||||||
| should use `spirv.Tosa.Table`(for $ 1/x $) and `spirv.Tosa.Mul`. | ||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Specify what happens on division by 0.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That is undefined behaviour and it is implementation dependent.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should document this in the op description |
||||||||
|
|
||||||||
| References: | ||||||||
| * https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_intdiv | ||||||||
| * https://www.mlplatform.org/tosa/tosa_spec_1_0_1.html#_intdiv | ||||||||
|
|
||||||||
| #### Example: | ||||||||
| ```mlir | ||||||||
| %0 = spirv.Tosa.IntDiv %arg0, %arg1 : !spirv.arm.tensor<1x65533x1xi32>, !spirv.arm.tensor<2x65533x1xi32> -> !spirv.arm.tensor<2x65533x1xi32> | ||||||||
| ``` | ||||||||
| }]; | ||||||||
|
|
||||||||
| let arguments = (ins | ||||||||
| SPIRV_Int32_TensorArm: $input1, | ||||||||
| SPIRV_Int32_TensorArm: $input2 | ||||||||
| ); | ||||||||
|
|
||||||||
| let results = (outs | ||||||||
| SPIRV_Int32_TensorArm: $output | ||||||||
| ); | ||||||||
|
|
||||||||
| let assemblyFormat = [{ | ||||||||
| $input1 `,` | ||||||||
| $input2 | ||||||||
| attr-dict `:` type(operands) `->` type(results) | ||||||||
| }]; | ||||||||
| } | ||||||||
|
|
||||||||
|
|
||||||||
| #endif // MLIR_DIALECT_SPIRV_IR_TOSA_OPS | ||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -431,3 +431,25 @@ spirv.ARM.Graph @clamp_max_val_different_element_type_wrt_input_output(%arg0: !s | |
| %3 = spirv.Tosa.Clamp min_val = -102 : i8, max_val = -100 : i16, nan_mode = <Propagate>, %arg0 : !spirv.arm.tensor<27x44x55xi8> -> !spirv.arm.tensor<27x44x55xi8> | ||
| spirv.ARM.GraphOutputs %3 : !spirv.arm.tensor<27x44x55xi8> | ||
| } | ||
|
|
||
| //===----------------------------------------------------------------------===// | ||
| // spirv.TOSA.Add | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is there a reason there are only negative tests for Add? |
||
| //===----------------------------------------------------------------------===// | ||
|
|
||
| spirv.ARM.Graph @add_input_ranks_not_matching(%arg0: !spirv.arm.tensor<6x10x6xi32>, %arg1: !spirv.arm.tensor<1x10x6x6xi32>) -> (!spirv.arm.tensor<6x10x6x6xi32>) { | ||
| // expected-error @+1 {{op failed to verify that all of {input1, input2} have same rank}} | ||
| %0 = spirv.Tosa.Add %arg0, %arg1 : !spirv.arm.tensor<6x10x6xi32>, !spirv.arm.tensor<1x10x6x6xi32> -> !spirv.arm.tensor<6x10x6x6xi32> | ||
| spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<6x10x6x6xi32> | ||
| } | ||
|
|
||
| spirv.ARM.Graph @add_input_element_types_not_matching(%arg0: !spirv.arm.tensor<6x10x6x6xi16>, %arg1: !spirv.arm.tensor<1x10x6x6xi32>) -> (!spirv.arm.tensor<6x10x6x6xi32>) { | ||
| // expected-error @+1 {{op failed to verify that all of {input1, input2, output} have same element type}} | ||
| %0 = spirv.Tosa.Add %arg0, %arg1 : !spirv.arm.tensor<6x10x6x6xi16>, !spirv.arm.tensor<1x10x6x6xi32> -> !spirv.arm.tensor<6x10x6x6xi32> | ||
| spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<6x10x6x6xi32> | ||
| } | ||
|
|
||
| spirv.ARM.Graph @add_input_output_element_types_not_matching(%arg0: !spirv.arm.tensor<6x10x6x6xi32>, %arg1: !spirv.arm.tensor<1x10x6x6xi32>) -> (!spirv.arm.tensor<6x10x6x6xi16>) { | ||
| // expected-error @+1 {{op failed to verify that all of {input1, input2, output} have same element type}} | ||
| %0 = spirv.Tosa.Add %arg0, %arg1 : !spirv.arm.tensor<6x10x6x6xi32>, !spirv.arm.tensor<1x10x6x6xi32> -> !spirv.arm.tensor<6x10x6x6xi16> | ||
| spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<6x10x6x6xi16> | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we also specify there that element types have to match?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We could but then we need to have MUL as a special case. Indeed the output of integer multiplication is always i32, even if the inputs are i8 and i16.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We can have multiple base classes: one for the more general case, one more specialized one for broad coverage