Skip to content

Commit d50b976

Browse files
authored
Merge pull request #246 from Sichao25/yus/functor
Replace lambda with functor
2 parents e745644 + 8b158c3 commit d50b976

File tree

1 file changed

+176
-72
lines changed

1 file changed

+176
-72
lines changed

src/pcms/interpolator/interpolation_base.cpp

Lines changed: 176 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,85 @@ void adapt_radii(unsigned min_req_supports, unsigned max_allowed_supports,
144144
Kokkos::fence();
145145
}
146146

147+
struct ScanSupportIdxFunctor
148+
{
149+
Omega_h::Write<Omega_h::LO> support_ptr_l;
150+
Omega_h::Write<Omega_h::LO> num_supports;
151+
152+
ScanSupportIdxFunctor(Omega_h::Write<Omega_h::LO> support_ptr_l_in,
153+
Omega_h::Write<Omega_h::LO> num_supports_in)
154+
: support_ptr_l(support_ptr_l_in), num_supports(num_supports_in)
155+
{
156+
}
157+
158+
KOKKOS_INLINE_FUNCTION
159+
void operator()(const int& i, unsigned& update, const bool final) const
160+
{
161+
update += num_supports[i];
162+
if (final) {
163+
support_ptr_l[i + 1] = update;
164+
}
165+
}
166+
};
167+
168+
struct FillSupportIdxFunctor
169+
{
170+
const int dim;
171+
const Omega_h::LO n_sources;
172+
const Omega_h::Write<Omega_h::LO> support_ptr_l;
173+
const Omega_h::Write<Omega_h::LO> supports_idx_l;
174+
const Omega_h::Reals target_coords_l;
175+
const Omega_h::Reals source_coords_l;
176+
const Omega_h::Write<Omega_h::Real> radii2_l;
177+
178+
FillSupportIdxFunctor(int dim_in, Omega_h::LO n_sources_in,
179+
Omega_h::Write<Omega_h::LO> support_ptr_l_in,
180+
Omega_h::Write<Omega_h::LO> supports_idx_l_in,
181+
Omega_h::Reals target_coords_l_in,
182+
Omega_h::Reals source_coords_l_in,
183+
Omega_h::Write<Omega_h::Real> radii2_l_in)
184+
: dim(dim_in),
185+
n_sources(n_sources_in),
186+
support_ptr_l(support_ptr_l_in),
187+
supports_idx_l(supports_idx_l_in),
188+
target_coords_l(target_coords_l_in),
189+
source_coords_l(source_coords_l_in),
190+
radii2_l(radii2_l_in)
191+
{
192+
}
193+
194+
KOKKOS_INLINE_FUNCTION
195+
void operator()(const int& target_id) const
196+
{
197+
auto target_radius2 = radii2_l[target_id];
198+
auto target_coord = Omega_h::Vector<3>{0, 0, 0};
199+
for (int d = 0; d < dim; ++d) {
200+
target_coord[d] = target_coords_l[target_id * dim + d];
201+
}
202+
203+
auto start_ptr = support_ptr_l[target_id];
204+
auto end_ptr = support_ptr_l[target_id + 1];
205+
206+
for (int source_id = 0; source_id < n_sources; source_id++) {
207+
auto source_coord = Omega_h::Vector<3>{0, 0, 0};
208+
for (int d = 0; d < dim; ++d) {
209+
source_coord[d] = source_coords_l[source_id * dim + d];
210+
}
211+
auto dist2 =
212+
pointDistanceSquared(source_coord[0], source_coord[1], source_coord[2],
213+
target_coord[0], target_coord[1], target_coord[2]);
214+
if (dist2 <= target_radius2) {
215+
supports_idx_l[start_ptr] = source_id;
216+
start_ptr++;
217+
OMEGA_H_CHECK_PRINTF(
218+
start_ptr <= end_ptr,
219+
"Support index out of bounds:start %d end %d target_id %d\n",
220+
start_ptr, end_ptr, target_id);
221+
}
222+
}
223+
}
224+
};
225+
147226
// TODO Merge this with distance based search like NeighborSearch for
148227
// consistency
149228
void MLSPointCloudInterpolation::fill_support_structure(
@@ -154,15 +233,8 @@ void MLSPointCloudInterpolation::fill_support_structure(
154233
auto support_ptr_l = Omega_h::Write<Omega_h::LO>(n_targets_ + 1, 0);
155234
unsigned total_supports = 0;
156235
Kokkos::fence();
157-
Kokkos::parallel_scan(
158-
"scan", n_targets_,
159-
KOKKOS_LAMBDA(const int& i, unsigned& update, const bool final) {
160-
update += num_supports[i];
161-
if (final) {
162-
support_ptr_l[i + 1] = update;
163-
}
164-
},
165-
total_supports);
236+
ScanSupportIdxFunctor scanFunctor(support_ptr_l, num_supports);
237+
Kokkos::parallel_scan("scan", n_targets_, scanFunctor, total_supports);
166238

167239
pcms::printInfo("Total supports found: %d\n", total_supports);
168240
// resize the support index
@@ -174,35 +246,10 @@ void MLSPointCloudInterpolation::fill_support_structure(
174246
const auto target_coords_l = target_coords_;
175247
const auto source_coords_l = source_coords_;
176248
const auto n_sources = n_sources_;
177-
Kokkos::parallel_for(
178-
"fill support index", n_targets_, KOKKOS_LAMBDA(const int& target_id) {
179-
auto target_radius2 = radii2_l[target_id];
180-
auto target_coord = Omega_h::Vector<3>{0, 0, 0};
181-
for (int d = 0; d < dim; ++d) {
182-
target_coord[d] = target_coords_l[target_id * dim + d];
183-
}
184-
185-
auto start_ptr = support_ptr_l[target_id];
186-
auto end_ptr = support_ptr_l[target_id + 1];
187-
188-
for (int source_id = 0; source_id < n_sources; source_id++) {
189-
auto source_coord = Omega_h::Vector<3>{0, 0, 0};
190-
for (int d = 0; d < dim; ++d) {
191-
source_coord[d] = source_coords_l[source_id * dim + d];
192-
}
193-
auto dist2 = pointDistanceSquared(source_coord[0], source_coord[1],
194-
source_coord[2], target_coord[0],
195-
target_coord[1], target_coord[2]);
196-
if (dist2 <= target_radius2) {
197-
support_idx_l[start_ptr] = source_id;
198-
start_ptr++;
199-
OMEGA_H_CHECK_PRINTF(
200-
start_ptr <= end_ptr,
201-
"Support index out of bounds:start %d end %d target_id %d\n",
202-
start_ptr, end_ptr, target_id);
203-
}
204-
}
205-
});
249+
FillSupportIdxFunctor fillSupportIdxFunctor(dim, n_sources, support_ptr_l,
250+
support_idx_l, target_coords_l,
251+
source_coords_l, radii2_l);
252+
Kokkos::parallel_for("fill support index", n_targets_, fillSupportIdxFunctor);
206253
Kokkos::fence();
207254

208255
// copy the support index to the supports
@@ -211,6 +258,54 @@ void MLSPointCloudInterpolation::fill_support_structure(
211258
supports_.supports_idx = Omega_h::LOs(support_idx_l);
212259
}
213260

261+
struct NSquareSearchFunctor
262+
{
263+
const int dim;
264+
const Omega_h::LO n_sources;
265+
const Omega_h::Reals target_coords_l;
266+
const Omega_h::Reals source_coords_l;
267+
const Omega_h::Write<Omega_h::Real> radii2_l;
268+
const Omega_h::Write<Omega_h::LO> num_supports_l;
269+
270+
NSquareSearchFunctor(int dim_in, Omega_h::LO n_sources_in,
271+
Omega_h::Reals target_coords_l_in,
272+
Omega_h::Reals source_coords_l_in,
273+
Omega_h::Write<Omega_h::Real> radii2_l_in,
274+
Omega_h::Write<Omega_h::LO> num_supports_l_in)
275+
: dim(dim_in),
276+
n_sources(n_sources_in),
277+
target_coords_l(target_coords_l_in),
278+
source_coords_l(source_coords_l_in),
279+
radii2_l(radii2_l_in),
280+
num_supports_l(num_supports_l_in)
281+
{
282+
}
283+
284+
KOKKOS_INLINE_FUNCTION
285+
void operator()(const int& target_id) const
286+
{
287+
auto target_coord = Omega_h::Vector<3>{0, 0, 0};
288+
for (int d = 0; d < dim; ++d) {
289+
target_coord[d] = target_coords_l[target_id * dim + d];
290+
}
291+
auto target_radius2 = radii2_l[target_id];
292+
293+
// TODO: parallel with kokkos parallel_for
294+
for (int i = 0; i < n_sources; i++) {
295+
auto source_coord = Omega_h::Vector<3>{0, 0, 0};
296+
for (int d = 0; d < dim; ++d) {
297+
source_coord[d] = source_coords_l[i * dim + d];
298+
}
299+
auto dist2 =
300+
pointDistanceSquared(source_coord[0], source_coord[1], source_coord[2],
301+
target_coord[0], target_coord[1], target_coord[2]);
302+
if (dist2 <= target_radius2) {
303+
num_supports_l[target_id]++; // only one thread is updating
304+
}
305+
}
306+
}
307+
};
308+
214309
// use uniform grid based point search when available
215310
void MLSPointCloudInterpolation::distance_based_pointcloud_search(
216311
Omega_h::Write<Omega_h::Real> radii2_l,
@@ -222,51 +317,60 @@ void MLSPointCloudInterpolation::distance_based_pointcloud_search(
222317
const auto source_coords_l = source_coords_;
223318
const auto n_sources = n_sources_;
224319
const auto n_targets = n_targets_;
225-
Kokkos::parallel_for(
226-
"n^2 search", n_targets, KOKKOS_LAMBDA(const int& target_id) {
227-
auto target_coord = Omega_h::Vector<3>{0, 0, 0};
228-
for (int d = 0; d < dim; ++d) {
229-
target_coord[d] = target_coords_l[target_id * dim + d];
230-
}
231-
auto target_radius2 = radii2_l[target_id];
232-
233-
// TODO: parallel with kokkos parallel_for
234-
for (int i = 0; i < n_sources; ++i) {
235-
auto source_coord = Omega_h::Vector<3>{0, 0, 0};
236-
for (int d = 0; d < dim; ++d) {
237-
source_coord[d] = source_coords_l[i * dim + d];
238-
}
239-
auto dist2 = pointDistanceSquared(source_coord[0], source_coord[1],
240-
source_coord[2], target_coord[0],
241-
target_coord[1], target_coord[2]);
242-
if (dist2 <= target_radius2) {
243-
num_supports[target_id]++; // only one thread is updating
244-
}
245-
}
246-
});
320+
NSquareSearchFunctor nSquareSearchFunctor(
321+
dim, n_sources, target_coords_l, source_coords_l, radii2_l, num_supports);
322+
Kokkos::parallel_for("n^2 search", n_targets, nSquareSearchFunctor);
247323
Kokkos::fence();
248324
}
249325

326+
struct PrintTargetPointsFunctor
327+
{
328+
const Omega_h::Reals target_coords_l;
329+
330+
PrintTargetPointsFunctor(Omega_h::Reals target_coords_l_in, int dim_in)
331+
: target_coords_l(target_coords_l_in)
332+
{
333+
}
334+
335+
KOKKOS_INLINE_FUNCTION
336+
void operator()(const int& i) const
337+
{
338+
pcms::printDebugInfo("Target Point %d: (%f, %f)\n", i,
339+
target_coords_l[i * 2 + 0],
340+
target_coords_l[i * 2 + 1]);
341+
}
342+
};
343+
344+
struct PrintSourcePointsFunctor
345+
{
346+
const Omega_h::Reals source_coords_l;
347+
348+
PrintSourcePointsFunctor(Omega_h::Reals source_coords_l_in)
349+
: source_coords_l(source_coords_l_in)
350+
{
351+
}
352+
353+
KOKKOS_INLINE_FUNCTION
354+
void operator()(const int& i) const
355+
{
356+
pcms::printDebugInfo("Source Point %d: (%f, %f)\n", i,
357+
source_coords_l[i * 2 + 0],
358+
source_coords_l[i * 2 + 1]);
359+
}
360+
};
361+
250362
void MLSPointCloudInterpolation::find_supports(unsigned min_req_supports,
251363
unsigned max_allowed_supports,
252364
unsigned max_count)
253365
{
254366
pcms::printDebugInfo("First 10 Target Points with %d points:\n", n_targets_);
255367
const auto target_coords_l = target_coords_;
256368
const auto source_coords_l = source_coords_;
257-
Omega_h::parallel_for(
258-
"print target points", 10, OMEGA_H_LAMBDA(const int& i) {
259-
pcms::printDebugInfo("Target Point %d: (%f, %f)\n", i,
260-
target_coords_l[i * 2 + 0],
261-
target_coords_l[i * 2 + 1]);
262-
});
369+
PrintTargetPointsFunctor printTargetPointsFunctor(target_coords_l, dim_);
370+
Kokkos::parallel_for("print target points", 10, printTargetPointsFunctor);
263371
pcms::printDebugInfo("First 10 Source Points with %d points:\n", n_sources_);
264-
Omega_h::parallel_for(
265-
"print source points", 10, OMEGA_H_LAMBDA(const int& i) {
266-
pcms::printDebugInfo("Source Point %d: (%f, %f)\n", i,
267-
source_coords_l[i * 2 + 0],
268-
source_coords_l[i * 2 + 1]);
269-
});
372+
PrintSourcePointsFunctor printSourcePointsFunctor(source_coords_l);
373+
Kokkos::parallel_for("print source points", 10, printSourcePointsFunctor);
270374

271375
auto radii2_l = Omega_h::Write<Omega_h::Real>(n_targets_, radius_);
272376
auto num_supports = Omega_h::Write<Omega_h::LO>(n_targets_, 0);

0 commit comments

Comments
 (0)