Skip to content

Commit 375eed0

Browse files
Wrap SuperLU_DIST solver options (#4072)
* Moving towards options class + wrap. Looks like a bit of a pain. * Simplify * Add safe casting to int_t on boundary of DOLFINx/SuperLU_DIST * Switch to comparing on max limit, tidy up whitespace. * Revert * Work in progress. * Add option setter for some options * Add a basic setter so that C++ users can just use native SuperLU interface. * Switch to YES and NO, following SuperLU_DIST * Small tweaks. * Tweaks. * Remove verbose flag * Tests run * Final tweaks and comments * Turn off output * Fix. * Fix * Fix * clang-format * Add comment on options. --------- Co-authored-by: Chris Richardson <chris@bpi.cam.ac.uk>
1 parent 473a654 commit 375eed0

File tree

5 files changed

+228
-62
lines changed

5 files changed

+228
-62
lines changed

cpp/dolfinx/la/superlu_dist.cpp

Lines changed: 132 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,10 @@ extern "C"
1414
#include <superlu_zdefs.h>
1515
}
1616
#include <algorithm>
17+
#include <dolfinx/common/Timer.h>
1718
#include <dolfinx/la/MatrixCSR.h>
1819
#include <dolfinx/la/Vector.h>
20+
#include <ranges>
1921
#include <stdexcept>
2022
#include <vector>
2123

@@ -46,6 +48,25 @@ namespace
4648
template <typename...>
4749
constexpr bool dependent_false_v = false;
4850

51+
template <typename V, typename W>
52+
void option_setter(std::string_view option_name, W& option,
53+
const std::initializer_list<V> values,
54+
std::initializer_list<std::string_view> value_names,
55+
std::string_view value_in)
56+
{
57+
// TODO: Can be done nicely with std::views::zip in C++23.
58+
for (auto i : std::views::iota(std::size_t{0}, value_names.size()))
59+
{
60+
if (value_in == *(value_names.begin() + i))
61+
{
62+
option = *(values.begin() + i);
63+
spdlog::info("Set {} to {}", option_name, value_in);
64+
return;
65+
}
66+
}
67+
throw std::runtime_error("Unsupported value");
68+
}
69+
4970
std::vector<int_t> col_indices(const auto& A)
5071
{
5172
// Local number of non-zeros
@@ -135,15 +156,13 @@ create_supermatrix(const auto& A, auto& rowptr, auto& cols)
135156
} // namespace
136157
//----------------------------------------------------------------------------
137158
template <typename T>
138-
SuperLUDistMatrix<T>::SuperLUDistMatrix(std::shared_ptr<const MatrixCSR<T>> A,
139-
bool verbose)
159+
SuperLUDistMatrix<T>::SuperLUDistMatrix(std::shared_ptr<const MatrixCSR<T>> A)
140160
: _matA(std::move(A)),
141161
_cols(
142162
std::make_unique<SuperLUDistStructs::vec_int_t>(col_indices(*_matA))),
143163
_rowptr(
144164
std::make_unique<SuperLUDistStructs::vec_int_t>(row_indices(*_matA))),
145-
_supermatrix(create_supermatrix<T>(*_matA, *_rowptr, *_cols)),
146-
_verbose(verbose)
165+
_supermatrix(create_supermatrix<T>(*_matA, *_rowptr, *_cols))
147166
{
148167
}
149168
//----------------------------------------------------------------------------
@@ -170,18 +189,32 @@ struct dolfinx::la::SuperLUDistStructs::gridinfo_t : public ::gridinfo_t
170189
{
171190
};
172191
//----------------------------------------------------------------------------
192+
struct dolfinx::la::SuperLUDistStructs::superlu_dist_options_t
193+
: public ::superlu_dist_options_t
194+
{
195+
};
196+
//----------------------------------------------------------------------------
173197
void GridInfoDeleter::operator()(
174198
SuperLUDistStructs::gridinfo_t* gridinfo) const noexcept
175199
{
176200
superlu_gridexit(gridinfo);
177201
delete gridinfo;
178-
}
202+
};
179203

180204
//----------------------------------------------------------------------------
181205
template <typename T>
182206
SuperLUDistSolver<T>::SuperLUDistSolver(
183-
std::shared_ptr<const SuperLUDistMatrix<T>> A, bool verbose)
207+
std::shared_ptr<const SuperLUDistMatrix<T>> A)
184208
: _superlu_matA(std::move(A)),
209+
_options(
210+
[]
211+
{
212+
auto o = std::make_unique<
213+
SuperLUDistStructs::superlu_dist_options_t>();
214+
set_default_options_dist(o.get());
215+
o->PrintStat = NO;
216+
return o;
217+
}()),
185218
_gridinfo(
186219
[comm = _superlu_matA->matA().comm()]
187220
{
@@ -191,28 +224,105 @@ SuperLUDistSolver<T>::SuperLUDistSolver(
191224
new SuperLUDistStructs::gridinfo_t, GridInfoDeleter{});
192225
superlu_gridinit(comm, nprow, npcol, p.get());
193226
return p;
194-
}()),
195-
_verbose(verbose)
227+
}())
228+
{
229+
}
230+
231+
template <typename T>
232+
void SuperLUDistSolver<T>::set_options(
233+
SuperLUDistStructs::superlu_dist_options_t options)
234+
{
235+
_options = std::make_unique<SuperLUDistStructs::superlu_dist_options_t>(
236+
std::move(options));
237+
}
238+
239+
template <typename T>
240+
void SuperLUDistSolver<T>::set_option(std::string name, std::string value)
196241
{
242+
spdlog::info("Attempting to set option {} to {}", name, value);
243+
const std::map<std::string, std::reference_wrapper<yes_no_t>> map_bool
244+
= {{"Equil", _options->Equil},
245+
{"DiagInv", _options->DiagInv},
246+
{"SymmetricMode", _options->SymmetricMode},
247+
{"PivotGrowth", _options->PivotGrowth},
248+
{"ConditionNumber", _options->ConditionNumber},
249+
{"ReplaceTinyPivot", _options->ReplaceTinyPivot},
250+
{"SolveInitialized", _options->SolveInitialized},
251+
{"RefineInitialized", _options->RefineInitialized},
252+
{"PrintStat", _options->PrintStat},
253+
{"lookahead_etree", _options->lookahead_etree},
254+
{"SymPattern", _options->SymPattern},
255+
{"Use_TensorCore", _options->Use_TensorCore},
256+
{"Algo3d", _options->Algo3d}};
257+
258+
// Search in map_bool first
259+
auto it = map_bool.find(name);
260+
if (it != map_bool.end())
261+
{
262+
if (value == "YES")
263+
{
264+
spdlog::info("Set {} to YES", name);
265+
it->second.get() = YES;
266+
}
267+
else if (value == "NO")
268+
{
269+
spdlog::info("Set {} to NO", name);
270+
it->second.get() = NO;
271+
}
272+
else
273+
{
274+
throw std::runtime_error("Boolean values must be string 'YES' or 'NO'");
275+
}
276+
}
277+
278+
// Search some enum types
279+
if (name == "Fact")
280+
{
281+
option_setter(
282+
name, _options->Fact,
283+
{DOFACT, SamePattern, SamePattern_SameRowPerm, FACTORED},
284+
{"DOFACT", "SamePattern", "SamePattern_SameRowPerm", "FACTORED"},
285+
value);
286+
}
287+
else if (name == "Trans")
288+
{
289+
option_setter(name, _options->Trans, {NOTRANS, TRANS, CONJ},
290+
{"NOTRANS", "TRANS", "CONJ"}, value);
291+
}
292+
else if (name == "ColPerm")
293+
{
294+
option_setter(name, _options->ColPerm,
295+
{NATURAL, MMD_ATA, MMD_AT_PLUS_A, COLAMD, METIS_AT_PLUS_A,
296+
PARMETIS, METIS_ATA, ZOLTAN, MY_PERMC},
297+
{"NATURAL", "MMD_ATA", "MMD_AT_PLUS_A", "COLAMD",
298+
"METIS_AT_PLUS_A", "PARMETIS", "METIS_ATA", "ZOLTAN",
299+
"MY_PERMC"},
300+
value);
301+
}
302+
else if (name == "RowPerm")
303+
{
304+
option_setter(name, _options->RowPerm,
305+
{NOROWPERM, LargeDiag_MC64, LargeDiag_HWPM, MY_PERMR},
306+
{"NOROWPERM", "LargeDiag_MC64", "LargeDiag_HWPM", "MY_PERMR"},
307+
value);
308+
}
309+
else
310+
{
311+
std::runtime_error("Unsupported name");
312+
}
197313
}
198314
//----------------------------------------------------------------------------
199315
template <typename T>
200316
int SuperLUDistSolver<T>::solve(const la::Vector<T>& b, la::Vector<T>& u) const
201317
{
318+
common::Timer tsolve("SuperLU Solve");
202319
int_t m = _superlu_matA->supermatrix()->nrow;
203320
int_t m_loc = ((NRformat_loc*)(_superlu_matA->supermatrix()->Store))->m_loc;
204321

205322
// RHS
206323
int_t ldb = m_loc;
207324
int_t nrhs = 1;
208325

209-
superlu_dist_options_t options;
210-
set_default_options_dist(&options);
211-
options.DiagInv = YES;
212-
options.ReplaceTinyPivot = YES;
213-
if (!_verbose)
214-
options.PrintStat = NO;
215-
216326
int info = 0;
217327
SuperLUStat_t stat;
218328
PStatInit(&stat);
@@ -232,12 +342,12 @@ int SuperLUDistSolver<T>::solve(const la::Vector<T>& b, la::Vector<T>& u) const
232342
dSOLVEstruct_t SOLVEstruct;
233343

234344
spdlog::info("Call SuperLU_DIST pdgssvx()");
235-
pdgssvx(&options, _superlu_matA->supermatrix(), &ScalePermstruct,
345+
pdgssvx(_options.get(), _superlu_matA->supermatrix(), &ScalePermstruct,
236346
u.array().data(), ldb, nrhs, _gridinfo.get(), &LUstruct,
237347
&SOLVEstruct, berr.data(), &stat, &info);
238348

239349
spdlog::info("Finalize solve");
240-
dSolveFinalize(&options, &SOLVEstruct);
350+
dSolveFinalize(_options.get(), &SOLVEstruct);
241351
dScalePermstructFree(&ScalePermstruct);
242352
dLUstructFree(&LUstruct);
243353
}
@@ -251,12 +361,12 @@ int SuperLUDistSolver<T>::solve(const la::Vector<T>& b, la::Vector<T>& u) const
251361
sSOLVEstruct_t SOLVEstruct;
252362

253363
spdlog::info("Call SuperLU_DIST psgssvx()");
254-
psgssvx(&options, _superlu_matA->supermatrix(), &ScalePermstruct,
364+
psgssvx(_options.get(), _superlu_matA->supermatrix(), &ScalePermstruct,
255365
u.array().data(), ldb, nrhs, _gridinfo.get(), &LUstruct,
256366
&SOLVEstruct, berr.data(), &stat, &info);
257367

258368
spdlog::info("Finalize solve");
259-
sSolveFinalize(&options, &SOLVEstruct);
369+
sSolveFinalize(_options.get(), &SOLVEstruct);
260370
sScalePermstructFree(&ScalePermstruct);
261371
sLUstructFree(&LUstruct);
262372
}
@@ -270,13 +380,13 @@ int SuperLUDistSolver<T>::solve(const la::Vector<T>& b, la::Vector<T>& u) const
270380
zSOLVEstruct_t SOLVEstruct;
271381

272382
spdlog::info("Call SuperLU_DIST pzgssvx()");
273-
pzgssvx(&options, _superlu_matA->supermatrix(), &ScalePermstruct,
383+
pzgssvx(_options.get(), _superlu_matA->supermatrix(), &ScalePermstruct,
274384
reinterpret_cast<doublecomplex*>(u.array().data()), ldb, nrhs,
275385
_gridinfo.get(), &LUstruct, &SOLVEstruct, berr.data(), &stat,
276386
&info);
277387

278388
spdlog::info("Finalize solve");
279-
zSolveFinalize(&options, &SOLVEstruct);
389+
zSolveFinalize(_options.get(), &SOLVEstruct);
280390
zScalePermstructFree(&ScalePermstruct);
281391
zLUstructFree(&LUstruct);
282392
}
@@ -287,8 +397,7 @@ int SuperLUDistSolver<T>::solve(const la::Vector<T>& b, la::Vector<T>& u) const
287397
if (info != 0)
288398
spdlog::info("SuperLU_DIST p*gssvx() error: {}", info);
289399

290-
if (_verbose)
291-
PStatPrint(&options, &stat, _gridinfo.get());
400+
PStatPrint(_options.get(), &stat, _gridinfo.get());
292401
PStatFree(&stat);
293402

294403
return info;

cpp/dolfinx/la/superlu_dist.h

Lines changed: 51 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include <dolfinx/la/MatrixCSR.h>
1212
#include <dolfinx/la/Vector.h>
1313
#include <memory>
14+
#include <string>
1415

1516
namespace dolfinx::la
1617
{
@@ -19,8 +20,9 @@ class SuperLUDistStructs
1920
{
2021
public:
2122
struct SuperMatrix;
22-
struct gridinfo_t;
2323
struct vec_int_t;
24+
struct gridinfo_t;
25+
struct superlu_dist_options_t;
2426
};
2527

2628
/// Call library cleanup and delete pointer. For use with
@@ -39,21 +41,17 @@ class SuperLUDistMatrix
3941
public:
4042
/// @brief Create SuperLU_DIST matrix operator.
4143
///
42-
/// Handles RAII-type memory management of underlying C objects.
43-
///
4444
/// @tparam T Scalar type.
4545
/// @param A Matrix.
46-
/// @param verbose Verbose output.
47-
SuperLUDistMatrix(std::shared_ptr<const MatrixCSR<T>> A,
48-
bool verbose = false);
46+
SuperLUDistMatrix(std::shared_ptr<const MatrixCSR<T>> A);
4947

5048
/// Copy constructor
5149
SuperLUDistMatrix(const SuperLUDistMatrix&) = delete;
5250

5351
/// Copy assignment
5452
SuperLUDistMatrix& operator=(const SuperLUDistMatrix&) = delete;
5553

56-
/// Get non-const pointer to SuperLU_DIST SuperMatrix.
54+
/// Get pointer to SuperLU_DIST SuperMatrix (non-const).
5755
SuperLUDistStructs::SuperMatrix* supermatrix() const;
5856

5957
/// Get MatrixCSR (const).
@@ -70,9 +68,6 @@ class SuperLUDistMatrix
7068
// Pointer to native SuperMatrix
7169
std::unique_ptr<SuperLUDistStructs::SuperMatrix, SuperMatrixDeleter>
7270
_supermatrix;
73-
74-
// Flag for diagnostic output
75-
bool _verbose;
7671
};
7772

7873
/// Call library cleanup and delete pointer. For use with
@@ -91,38 +86,71 @@ class SuperLUDistSolver
9186
public:
9287
/// @brief Create solver for a SuperLU_DIST matrix operator.
9388
///
94-
/// Solves Au = b using SuperLU_DIST.
89+
/// Solves linear system Au = b via LU decomposition.
90+
///
91+
/// The SuperLU_DIST solver has options set to upstream defaults,
92+
/// except PrintStat (verbose solver output) set to NO.
9593
///
9694
/// @tparam T Scalar type.
97-
/// @param A Matrix to solve for.
98-
/// @param verbose Verbose output.
99-
SuperLUDistSolver(std::shared_ptr<const SuperLUDistMatrix<T>> A,
100-
bool verbose = false);
95+
/// @param A Assembled left-hand side matrix.
96+
SuperLUDistSolver(std::shared_ptr<const SuperLUDistMatrix<T>> A);
10197

10298
/// Copy constructor
10399
SuperLUDistSolver(const SuperLUDistSolver&) = delete;
104100

105101
/// Copy assignment
106102
SuperLUDistSolver& operator=(const SuperLUDistSolver&) = delete;
107103

104+
/// @brief Set solver option (name, value)
105+
///
106+
/// See SuperLU_DIST User's Guide for option names and values.
107+
///
108+
/// @param name Option name.
109+
/// @param value Option value.
110+
void set_option(std::string name, std::string value);
111+
112+
/// @brief Set all solver options (native struct)
113+
///
114+
/// See SuperLU_DIST User's Guide for option names and values.
115+
///
116+
/// Callers must complete the forward declared struct, e.g.:
117+
///
118+
/// ```cpp
119+
/// #include <superlu_defs.h>
120+
/// struct dolfinx::la::SuperLUDistStructs::superlu_dist_options_t
121+
/// : public ::superlu_dist_options_t
122+
/// {
123+
/// };
124+
///
125+
/// SuperLUDistStructs::superlu_dist_options_t options;
126+
/// set_default_options_dist(&options);
127+
/// options.PrintStat = YES;
128+
/// // Setup SuperLUDistMatrix and SuperLUDistSolver
129+
/// solver.set_options(options);
130+
/// ```
131+
///
132+
/// @param options SuperLU_DIST option struct.
133+
void set_options(SuperLUDistStructs::superlu_dist_options_t options);
134+
108135
/// @brief Solve linear system Au = b.
109136
///
110137
/// @param b Right-hand side vector.
111-
/// @param u Solution vector.
112-
/// @returns SuperLU_DIST info flag.
113-
/// @note Vectors must have size and parallel layout that is
114-
/// compatible with `A`.
138+
/// @param u Solution vector, overwritten during solve.
139+
/// @returns SuperLU_DIST info integer.
140+
/// @note The caller must check the return code for success `(== 0)`.
141+
/// @note The caller must `u.scatter_forward()` after the solve.
142+
/// @note Vectors must have size and parallel layout compatible with `A`.
115143
int solve(const Vector<T>& b, Vector<T>& u) const;
116144

117145
private:
118-
// Wrapped SuperLU SuperMatrix
146+
// Assembled left-hand side matrix
119147
std::shared_ptr<const SuperLUDistMatrix<T>> _superlu_matA;
120148

149+
// Pointer to struct superlu_dist_options_t
150+
std::unique_ptr<SuperLUDistStructs::superlu_dist_options_t> _options;
151+
121152
// Pointer to struct gridinfo_t
122153
std::unique_ptr<SuperLUDistStructs::gridinfo_t, GridInfoDeleter> _gridinfo;
123-
124-
// Flag for diagnostic output
125-
bool _verbose;
126154
};
127155
} // namespace dolfinx::la
128156
#endif

0 commit comments

Comments
 (0)