Skip to content

Commit 380ac70

Browse files
committed
Avoid calculating basis twice for 2 fields, this version in USNCCM18 performance
1 parent d89991e commit 380ac70

File tree

2 files changed

+76
-3
lines changed

2 files changed

+76
-3
lines changed

src/pmpo_MPMesh.cpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -326,12 +326,16 @@ void MPMesh::push_ahead(){
326326
Kokkos::Timer timer;
327327
//Latitude Longitude increment at mesh vertices and interpolate to particle position
328328
p_mesh->computeRotLatLonIncr();
329-
sphericalInterpolation<MeshF_RotLatLonIncr>(*this);
329+
330+
//sphericalInterpolation<MeshF_RotLatLonIncr>(*this);
330331
//Interpolate mesh velocity increments to particle positions
331332
//Note that the basis functions are created twice and so need to avoid redundant calculations
332333
//Tried template lists Template_Type... maybe better option available
333-
Kokkos::fence();
334-
sphericalInterpolation<MeshF_OnSurfVeloIncr>(*this);
334+
//Kokkos::fence();
335+
//sphericalInterpolation<MeshF_OnSurfVeloIncr>(*this);
336+
337+
sphericalInterpolation1(*this);
338+
335339
//Push the MPs
336340
p_MPs->updateRotLatLonAndXYZ2Tgt(p_mesh->getSphereRadius());
337341
pumipic::RecordTime("PolyMPO_interpolateAndPush", timer.seconds());

src/pmpo_wachspressBasis.hpp

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -358,5 +358,74 @@ void sphericalInterpolation(MPMesh& mpMesh){
358358
pumipic::RecordTime("PolyMPO_sphericalInterpolation", timer.seconds());
359359
}
360360

361+
362+
/**
 * Interpolate two mesh vertex fields to the material points (MPs) in one pass.
 *
 * For each MP whose mask is set, the vertex coordinates of its element are
 * gathered, getBasisByAreaGblFormSpherical is called once to fill the
 * per-vertex basis values, and those values are reused to interpolate both
 * MeshF_RotLatLonIncr and MeshF_OnSurfVeloIncr into the MP slices selected by
 * meshFieldIndexToMPSlice. This avoids evaluating the basis once per field.
 *
 * The sphere radius must be positive (checked with PMT_ALWAYS_ASSERT). The
 * elapsed time is recorded under "PolyMPO_sphericalInterpolation1".
 *
 * @param mpMesh supplies p_mesh (vertex coordinates, element-to-vertex
 *               connectivity, source fields, sphere radius) and p_MPs
 *               (MP positions and the destination slices).
 */
inline void sphericalInterpolation1(MPMesh& mpMesh){
  Kokkos::Timer timer;
  auto p_mesh = mpMesh.p_mesh;
  auto vtxCoords = p_mesh->getMeshField<polyMPO::MeshF_VtxCoords>();
  auto elm2VtxConn = p_mesh->getElm2VtxConn();

  auto p_MPs = mpMesh.p_MPs;
  auto MPsPosition = p_MPs->getPositions();
  double radius = p_mesh->getSphereRadius();
  PMT_ALWAYS_ASSERT(radius > 0);

  // Source mesh fields and the MP slices they are interpolated into.
  constexpr MeshFieldIndex meshFieldIndex1 = polyMPO::MeshF_RotLatLonIncr;
  constexpr MeshFieldIndex meshFieldIndex2 = polyMPO::MeshF_OnSurfVeloIncr;

  auto meshField1 = p_mesh->getMeshField<meshFieldIndex1>();
  auto meshField2 = p_mesh->getMeshField<meshFieldIndex2>();

  constexpr MaterialPointSlice mpfIndex1 = meshFieldIndexToMPSlice<meshFieldIndex1>;
  constexpr MaterialPointSlice mpfIndex2 = meshFieldIndexToMPSlice<meshFieldIndex2>;

  const int numEntries1 = mpSliceToNumEntries<mpfIndex1>();
  const int numEntries2 = mpSliceToNumEntries<mpfIndex2>();

  auto mpField1 = p_MPs->getData<mpfIndex1>();
  auto mpField2 = p_MPs->getData<mpfIndex2>();

  auto interpolation = PS_LAMBDA(const int& elm, const int& mp, const int& mask) {
    if(mask) {
      Vec3d position3d(MPsPosition(mp, 0), MPsPosition(mp, 1), MPsPosition(mp, 2));

      // Entry 0 of the connectivity row is the vertex count; entries 1..numVtx
      // are vertex ids that are shifted by one to index the vertex arrays.
      const int numVtx = elm2VtxConn(elm, 0);

      // Resolve each 0-based vertex id once and reuse it for the coordinates
      // and for both field interpolations below.
      int vtxIds[maxVtxsPerElm];
      Vec3d v3d[maxVtxsPerElm + 1];
      for (int i = 0; i < numVtx; i++) {
        const int vtx = elm2VtxConn(elm, i + 1) - 1;
        vtxIds[i] = vtx;
        v3d[i][0] = vtxCoords(vtx, 0);
        v3d[i][1] = vtxCoords(vtx, 1);
        v3d[i][2] = vtxCoords(vtx, 2);
      }
      // The first vertex is repeated after the last one.
      const int firstVtx = elm2VtxConn(elm, 1) - 1;
      v3d[numVtx][0] = vtxCoords(firstVtx, 0);
      v3d[numVtx][1] = vtxCoords(firstVtx, 1);
      v3d[numVtx][2] = vtxCoords(firstVtx, 2);

      // Aggregate initialization zeroes every element; no second fill needed.
      double basisByArea3d[maxVtxsPerElm] = {0.0};

      // Single basis evaluation shared by both fields.
      getBasisByAreaGblFormSpherical(position3d, numVtx, v3d, radius, basisByArea3d);

      // Field 1: sum of vertex value times basis value, per entry.
      for(int entry=0; entry<numEntries1; entry++){
        double mpValue = 0.0;
        for(int i=0; i<numVtx; i++){
          mpValue += meshField1(vtxIds[i],entry)*basisByArea3d[i];
        }
        mpField1(mp,entry) = mpValue;
      }

      // Field 2: same basis values, second source field.
      for(int entry=0; entry<numEntries2; entry++){
        double mpValue = 0.0;
        for(int i=0; i<numVtx; i++){
          mpValue += meshField2(vtxIds[i],entry)*basisByArea3d[i];
        }
        mpField2(mp,entry) = mpValue;
      }
    }
  };
  p_MPs->parallel_for(interpolation, "sphericalInterpolationMultiField");
  pumipic::RecordTime("PolyMPO_sphericalInterpolation1", timer.seconds());
}
428+
429+
361430
} //namespace polyMPO end
362431
#endif

0 commit comments

Comments
 (0)