Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,11 @@
#include "shamrock/solvergraph/FieldRefs.hpp"
#include "shamrock/solvergraph/Indexes.hpp"
#include "shamrock/solvergraph/ScalarsEdge.hpp"
#include "shamrock/solvergraph/SolverGraph.hpp"
#include "shamsys/legacy/log.hpp"

// GSPH solvergraph edges
#include "shammodels/gsph/solvergraph/GhostCacheEdge.hpp"
#include "shamtree/CompressedLeafBVH.hpp"
#include "shamtree/KarrasRadixTreeField.hpp"
#include "shamtree/RadixTree.hpp"
Expand Down Expand Up @@ -70,6 +74,8 @@ namespace shammodels::gsph {

using RTree = shamtree::CompressedLeafBVH<Tmorton, Tvec, 3>;

shamrock::solvergraph::SolverGraph solver_graph;

/// Particle counts per patch
std::shared_ptr<shamrock::solvergraph::Indexes<u32>> part_counts;
std::shared_ptr<shamrock::solvergraph::Indexes<u32>> part_counts_with_ghost;
Expand All @@ -89,7 +95,9 @@ namespace shammodels::gsph {

/// Ghost handler for boundary particles
Component<GhostHandle> ghost_handler;
Component<GhostHandleCache> ghost_patch_cache;

/// Ghost interface cache - managed via SolverGraph
std::shared_ptr<solvergraph::GhostCacheEdge<Tvec>> ghost_cache;

/// Merged position-h data for neighbor search
Component<shambase::DistributedData<shamrock::patch::PatchDataLayer>> merged_xyzh;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
// -------------------------------------------------------//
//
// SHAMROCK code for hydrodynamics
// Copyright (c) 2021-2026 Timothée David--Cléris <tim.shamrock@proton.me>
// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1
// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information
//
// -------------------------------------------------------//

#pragma once

/**
* @file GhostCacheEdge.hpp
* @author Guo Yansong (guo.yansong.ngy@gmail.com)
* @brief SolverGraph edge for GSPH ghost cache
*/

#include "shambase/memory.hpp"
#include "shammodels/gsph/modules/GSPHGhostHandler.hpp"
#include "shamrock/solvergraph/IEdgeNamed.hpp"
#include <optional>

namespace shammodels::gsph::solvergraph {

/// SolverGraph edge for ghost interface cache
template<class Tvec>
class GhostCacheEdge : public shamrock::solvergraph::IEdgeNamed {
public:
using IEdgeNamed::IEdgeNamed;
using GhostHandle = GSPHGhostHandler<Tvec>;
using CacheMap = typename GhostHandle::CacheMap;

std::optional<CacheMap> cache;

CacheMap &get() {
if (!cache.has_value()) {
shambase::throw_with_loc<std::runtime_error>("GhostCache not set");
}
return cache.value();
}

const CacheMap &get() const {
if (!cache.has_value()) {
shambase::throw_with_loc<std::runtime_error>("GhostCache not set");
}
return cache.value();
}

bool has_value() const { return cache.has_value(); }
void set(CacheMap &&c) { cache = std::move(c); }
inline virtual void free_alloc() override { cache.reset(); }
};

} // namespace shammodels::gsph::solvergraph
26 changes: 16 additions & 10 deletions src/shammodels/gsph/src/Solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,11 @@ void shammodels::gsph::Solver<Tvec, Kern>::init_solver_graph() {
storage.neigh_cache
= std::make_shared<shammodels::sph::solvergraph::NeighCache>(edges::neigh_cache, "neigh");

// Register ghost cache edge for dependency tracking
storage.ghost_cache = storage.solver_graph.register_edge(
"ghost_cache",
solvergraph::GhostCacheEdge<Tvec>("ghost_cache", "\\mathcal{C}_{\\rm ghost}"));
Comment on lines +90 to +93

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This is good to register the ghost cache edge for dependency tracking. However, consider adding a comment explaining why dependency tracking is important in this context. This will help future developers understand the purpose of this registration.

Also, the name ghost_cache is repeated in the register_edge call. It would be beneficial to define this name as a constant to avoid potential typos and improve maintainability.

Suggested change
// Register ghost cache edge for dependency tracking
storage.ghost_cache = storage.solver_graph.register_edge(
"ghost_cache",
solvergraph::GhostCacheEdge<Tvec>("ghost_cache", "\\mathcal{C}_{\\rm ghost}"));
// Define a constant for the ghost cache edge name to avoid typos and improve maintainability
constexpr char const* ghost_cache_edge_name = "ghost_cache";
// Register ghost cache edge for dependency tracking
// Dependency tracking ensures that the ghost cache is valid before being used in computations
storage.ghost_cache = storage.solver_graph.register_edge(
ghost_cache_edge_name,
solvergraph::GhostCacheEdge<Tvec>(ghost_cache_edge_name, "\\mathcal{C}_{\\rm ghost}"));


storage.omega = std::make_shared<shamrock::solvergraph::Field<Tscal>>(1, "omega", "\\Omega");
storage.density = std::make_shared<shamrock::solvergraph::Field<Tscal>>(1, "density", "\\rho");
storage.pressure = std::make_shared<shamrock::solvergraph::Field<Tscal>>(1, "pressure", "P");
Expand Down Expand Up @@ -174,24 +179,25 @@ void shammodels::gsph::Solver<Tvec, Kern>::build_ghost_cache() {
using GSPHUtils = GSPHUtilities<Tvec, Kernel>;
GSPHUtils gsph_utils(scheduler());

storage.ghost_patch_cache.set(gsph_utils.build_interf_cache(
storage.ghost_handler.get(),
storage.serial_patch_tree.get(),
solver_config.htol_up_coarse_cycle));
shambase::get_check_ref(storage.ghost_cache)
.set(gsph_utils.build_interf_cache(
storage.ghost_handler.get(),
storage.serial_patch_tree.get(),
solver_config.htol_up_coarse_cycle));
Comment on lines +182 to +186

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Using shambase::get_check_ref here is good for ensuring that the ghost_cache is properly initialized before use. However, consider adding a comment explaining why this check is necessary and what the consequences are if the cache is not properly set. This will help future developers understand the importance of this check.

shambase::get_check_ref(storage.ghost_cache)
    // get_check_ref ensures that the ghost_cache is properly initialized before use.
    // If the ghost_cache is not properly set, it throws an exception to prevent further computation with invalid cache.
    .set(gsph_utils.build_interf_cache(
        storage.ghost_handler.get(),
        storage.serial_patch_tree.get(),
        solver_config.htol_up_coarse_cycle));

}

template<class Tvec, template<class> class Kern>
void shammodels::gsph::Solver<Tvec, Kern>::clear_ghost_cache() {
StackEntry stack_loc{};
storage.ghost_patch_cache.reset();
shambase::get_check_ref(storage.ghost_cache).free_alloc();

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Similar to the previous comment, it would be helpful to add a comment explaining why shambase::get_check_ref is used here and what the consequences are if the cache is not properly set before freeing the allocated cache.

shambase::get_check_ref(storage.ghost_cache)
    // get_check_ref ensures that the ghost_cache is properly initialized before use.
    // If the ghost_cache is not properly set, it throws an exception to prevent attempting to free an invalid cache.
    .free_alloc();

}

template<class Tvec, template<class> class Kern>
void shammodels::gsph::Solver<Tvec, Kern>::merge_position_ghost() {
StackEntry stack_loc{};

storage.merged_xyzh.set(
storage.ghost_handler.get().build_comm_merge_positions(storage.ghost_patch_cache.get()));
storage.merged_xyzh.set(storage.ghost_handler.get().build_comm_merge_positions(
shambase::get_check_ref(storage.ghost_cache).get()));

// Get field indices from xyzh_ghost_layout
const u32 ixyz_ghost
Expand Down Expand Up @@ -683,7 +689,7 @@ void shammodels::gsph::Solver<Tvec, Kern>::communicate_merge_ghosts_fields() {

// Build interface data from ghost cache
auto pdat_interf = ghost_handle.template build_interface_native<PatchDataLayer>(
storage.ghost_patch_cache.get(),
shambase::get_check_ref(storage.ghost_cache).get(),
[&](u64 sender, u64, InterfaceBuildInfos binfo, sham::DeviceBuffer<u32> &buf_idx, u32 cnt) {
PatchDataLayer pdat(ghost_layout_ptr);
pdat.reserve(cnt);
Expand All @@ -692,7 +698,7 @@ void shammodels::gsph::Solver<Tvec, Kern>::communicate_merge_ghosts_fields() {

// Populate interface data with field values
ghost_handle.template modify_interface_native<PatchDataLayer>(
storage.ghost_patch_cache.get(),
shambase::get_check_ref(storage.ghost_cache).get(),
pdat_interf,
[&](u64 sender,
u64,
Expand Down Expand Up @@ -733,7 +739,7 @@ void shammodels::gsph::Solver<Tvec, Kern>::communicate_merge_ghosts_fields() {

// Apply velocity offset for periodic boundaries
ghost_handle.template modify_interface_native<PatchDataLayer>(
storage.ghost_patch_cache.get(),
shambase::get_check_ref(storage.ghost_cache).get(),
pdat_interf,
[&](u64 sender,
u64,
Expand Down