Skip to content

Commit f3e482d

Browse files
committed
symbol+scope:: now viewID ignores one-sized dimensions
1 parent 0890fa4 commit f3e482d

File tree

4 files changed

+74
-13
lines changed

4 files changed

+74
-13
lines changed

core/jitk/engines/engine.cpp

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -103,25 +103,32 @@ void Engine::writeBlock(const SymbolTable &symbols,
103103
jitk::Scope scope(symbols, parent_scope, local_tmps, scalar_replaced_reduction_outputs, srio);
104104

105105
// Write temporary and scalar replaced array declarations
106-
vector<const bh_view *> scalar_replaced_to_write_back;
106+
vector<pair<const bh_view *, int> > scalar_replaced_to_write_back; // Pair of the view and hidden_axis
107107
for (const jitk::Block &block: kernel._block_list) {
108108
if (block.isInstr()) {
109109
const jitk::InstrPtr &instr = block.getInstr();
110-
for (const bh_view &view: instr->getViews()) {
110+
for (size_t o = 0; o < instr->operand.size(); ++o) {
111+
const bh_view &view = instr->operand[o];
111112
if (not scope.isDeclared(view)) {
112113
if (scope.isTmp(view.base)) {
113114
util::spaces(out, 8 + kernel.rank * 4);
114115
scope.writeDeclaration(view, writeType(view.base->type), out);
115116
out << "\n";
116117
} else if (scope.isScalarReplaced(view)) {
118+
// If 'instr' is a reduction we have to ignore the reduced axis when declaring the output
119+
// array (but only if we are reducing to a non-scalar).
120+
int hidden_axis = BH_MAXDIM; // Note, `BH_MAXDIM` means on hidden axis
121+
if (o == 0 and bh_opcode_is_reduction(instr->opcode) and instr->operand[1].ndim > 1) {
122+
hidden_axis = instr->sweep_axis();
123+
}
117124
util::spaces(out, 8 + kernel.rank * 4);
118125
scope.writeDeclaration(view, writeType(view.base->type), out);
119126
out << " " << scope.getName(view) << " = a" << symbols.baseID(view.base);
120-
write_array_subscription(scope, view, out);
127+
write_array_subscription(scope, view, out, false, hidden_axis);
121128
out << ";";
122129
out << "\n";
123130
if (scope.isScalarReplaced_RW(view)) {
124-
scalar_replaced_to_write_back.push_back(&view);
131+
scalar_replaced_to_write_back.emplace_back(&view, hidden_axis);
125132
}
126133
}
127134
}
@@ -191,11 +198,13 @@ void Engine::writeBlock(const SymbolTable &symbols,
191198
}
192199
}
193200

194-
// Let's copy the scalar replaced reduction outputs back to the original array
195-
for (const bh_view *view: scalar_replaced_to_write_back) {
201+
// Let's copy the scalar replaced back to the original array
202+
for (const auto view_and_hidden_axis: scalar_replaced_to_write_back) {
203+
const bh_view *view = view_and_hidden_axis.first;
204+
const int hidden_axis = view_and_hidden_axis.second;
196205
util::spaces(out, 8 + kernel.rank * 4);
197206
out << "a" << symbols.baseID(view->base);
198-
write_array_subscription(scope, *view, out, true);
207+
write_array_subscription(scope, *view, out, true, hidden_axis);
199208
out << " = ";
200209
scope.getName(*view, out);
201210
out << ";\n";

include/jitk/codegen_util.hpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,6 @@ void create_directories(const boost::filesystem::path &path);
107107
// This makes the source of the kernels more identical, which improve the code and compile caches.
108108
std::vector<InstrPtr> order_sweep_set(const std::set<InstrPtr> &sweep_set, const SymbolTable &symbols);
109109

110-
111110
// Returns True when `view` is accessing row major style
112111
bool row_major_access(const bh_view &view);
113112

include/jitk/scope.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,12 +37,12 @@ class Scope {
3737
const Scope * const parent;
3838
private:
3939
std::set<const bh_base*> _tmps; // Set of temporary arrays
40-
std::set<bh_view> _scalar_replacements_rw; // Set of scalar replaced arrays that both reads and writes
41-
std::set<bh_view> _scalar_replacements_r; // Set of scalar replaced arrays
40+
std::set<bh_view, IgnoreOneDim_less> _scalar_replacements_rw; // Set of scalar replaced arrays that both reads and writes
41+
std::set<bh_view, IgnoreOneDim_less> _scalar_replacements_r; // Set of scalar replaced arrays
4242
std::set<InstrPtr> _omp_atomic; // Set of instructions that should be guarded by OpenMP atomic
4343
std::set<InstrPtr> _omp_critical; // Set of instructions that should be guarded by OpenMP critical
4444
std::set<bh_base*> _declared_base; // Set of bases that have been locally declared (e.g. a temporary variable)
45-
std::set<bh_view> _declared_view; // Set of views that have been locally declared (e.g. a temporary variable)
45+
std::set<bh_view, IgnoreOneDim_less> _declared_view; // Set of views that have been locally declared (e.g. scalar replaced variable)
4646
std::set<bh_view, OffsetAndStrides_less> _declared_idx; // Set of indexes that have been locally declared
4747
public:
4848
template<typename T1, typename T2>

include/jitk/symbol_table.hpp

Lines changed: 55 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,17 +56,70 @@ struct OffsetAndStrides_less {
5656

5757
// Compare class for the constant_map
5858
struct Constant_less {
59-
// This compare tje 'origin_id' member of the instructions
59+
// This compare the 'origin_id' member of the instructions
6060
bool operator() (const InstrPtr &i1, const InstrPtr& i2) const {
6161
return i1->origin_id < i2->origin_id;
6262
}
6363
};
6464

65+
// Compare class for the viewID sets and maps
66+
struct IgnoreOneDim_less {
67+
BhIntVec get_shape_where_shape_is_greater_than_one(const bh_view &view) const {
68+
BhIntVec ret;
69+
for (int64_t i = 0; i < view.ndim; ++i) {
70+
if (view.shape[i] > 1) {
71+
ret.push_back(view.shape[i]);
72+
}
73+
}
74+
return ret;
75+
}
76+
77+
BhIntVec get_stride_where_shape_is_greater_than_one(const bh_view &view) const {
78+
BhIntVec ret;
79+
for (int64_t i = 0; i < view.ndim; ++i) {
80+
if (view.shape[i] > 1) {
81+
ret.push_back(view.stride[i]);
82+
}
83+
}
84+
return ret;
85+
}
86+
87+
// This compare is the same as view compare ('v1 < v2') but ignoring their bases and zero or one-sized dimensions
88+
bool operator() (const bh_view& v1, const bh_view& v2) const {
89+
if (v1.base < v2.base) return true;
90+
if (v2.base < v1.base) return false;
91+
if (v1.start < v2.start) return true;
92+
if (v2.start < v1.start) return false;
93+
94+
auto v1_shape = get_shape_where_shape_is_greater_than_one(v1);
95+
auto v2_shape = get_shape_where_shape_is_greater_than_one(v2);
96+
if (v1_shape.size() < v2_shape.size()) return true;
97+
if (v2_shape.size() < v1_shape.size()) return false;
98+
99+
auto v1_stride = get_stride_where_shape_is_greater_than_one(v1);
100+
auto v2_stride = get_stride_where_shape_is_greater_than_one(v2);
101+
assert(v1_shape.size() == v1_stride.size());
102+
assert(v2_shape.size() == v2_stride.size());
103+
104+
for (size_t i=0; i < v1_shape.size(); ++i) {
105+
if (v1_stride[i] < v2_stride[i]) return true;
106+
if (v2_stride[i] < v1_stride[i]) return false;
107+
if (v1_shape[i] < v2_shape[i]) return true;
108+
if (v2_shape[i] < v1_shape[i]) return false;
109+
}
110+
return false;
111+
}
112+
bool operator() (const bh_view* v1, const bh_view* v2) const {
113+
return (*this)(*v1, *v2);
114+
}
115+
};
116+
117+
65118
// The SymbolTable class contains all array meta date needed for a JIT kernel.
66119
class SymbolTable {
67120
private:
68121
std::map<const bh_base*, size_t> _base_map; // Mapping a base to its ID
69-
std::map<bh_view, size_t> _view_map; // Mapping a view to its ID
122+
std::map<bh_view, size_t, IgnoreOneDim_less> _view_map; // Mapping a view to its ID
70123
std::map<bh_view, size_t, OffsetAndStrides_less> _idx_map; // Mapping a index (of an array) to its ID
71124
std::map<bh_view, size_t, OffsetAndStrides_less> _offset_strides_map; // Mapping a offset-and-strides to its ID
72125
std::vector<const bh_view*> _offset_stride_views; // Vector of all offset-and-stride views

0 commit comments

Comments
 (0)