@@ -103,25 +103,32 @@ void Engine::writeBlock(const SymbolTable &symbols,
103103 jitk::Scope scope (symbols, parent_scope, local_tmps, scalar_replaced_reduction_outputs, srio);
104104
105105 // Write temporary and scalar replaced array declarations
106- vector<const bh_view *> scalar_replaced_to_write_back;
106+ vector<pair< const bh_view *, int > > scalar_replaced_to_write_back; // Pair of the view and hidden_axis
107107 for (const jitk::Block &block: kernel._block_list ) {
108108 if (block.isInstr ()) {
109109 const jitk::InstrPtr &instr = block.getInstr ();
110- for (const bh_view &view: instr->getViews ()) {
110+ for (size_t o = 0 ; o < instr->operand .size (); ++o) {
111+ const bh_view &view = instr->operand [o];
111112 if (not scope.isDeclared (view)) {
112113 if (scope.isTmp (view.base )) {
113114 util::spaces (out, 8 + kernel.rank * 4 );
114115 scope.writeDeclaration (view, writeType (view.base ->type ), out);
115116 out << " \n " ;
116117 } else if (scope.isScalarReplaced (view)) {
118+ // If 'instr' is a reduction we have to ignore the reduced axis when declaring the output
119+ // array (but only if we are reducing to a non-scalar).
120+ int hidden_axis = BH_MAXDIM; // Note, `BH_MAXDIM` means on hidden axis
121+ if (o == 0 and bh_opcode_is_reduction (instr->opcode ) and instr->operand [1 ].ndim > 1 ) {
122+ hidden_axis = instr->sweep_axis ();
123+ }
117124 util::spaces (out, 8 + kernel.rank * 4 );
118125 scope.writeDeclaration (view, writeType (view.base ->type ), out);
119126 out << " " << scope.getName (view) << " = a" << symbols.baseID (view.base );
120- write_array_subscription (scope, view, out);
127+ write_array_subscription (scope, view, out, false , hidden_axis );
121128 out << " ;" ;
122129 out << " \n " ;
123130 if (scope.isScalarReplaced_RW (view)) {
124- scalar_replaced_to_write_back.push_back (&view);
131+ scalar_replaced_to_write_back.emplace_back (&view, hidden_axis );
125132 }
126133 }
127134 }
@@ -191,11 +198,13 @@ void Engine::writeBlock(const SymbolTable &symbols,
191198 }
192199 }
193200
194- // Let's copy the scalar replaced reduction outputs back to the original array
195- for (const bh_view *view: scalar_replaced_to_write_back) {
201+ // Let's copy the scalar replaced back to the original array
202+ for (const auto view_and_hidden_axis: scalar_replaced_to_write_back) {
203+ const bh_view *view = view_and_hidden_axis.first ;
204+ const int hidden_axis = view_and_hidden_axis.second ;
196205 util::spaces (out, 8 + kernel.rank * 4 );
197206 out << " a" << symbols.baseID (view->base );
198- write_array_subscription (scope, *view, out, true );
207+ write_array_subscription (scope, *view, out, true , hidden_axis );
199208 out << " = " ;
200209 scope.getName (*view, out);
201210 out << " ;\n " ;
0 commit comments