@@ -14,8 +14,10 @@ extern "C"
1414#include < superlu_zdefs.h>
1515}
1616#include < algorithm>
17+ #include < dolfinx/common/Timer.h>
1718#include < dolfinx/la/MatrixCSR.h>
1819#include < dolfinx/la/Vector.h>
20+ #include < ranges>
1921#include < stdexcept>
2022#include < vector>
2123
@@ -46,6 +48,25 @@ namespace
4648template <typename ...>
4749constexpr bool dependent_false_v = false ;
4850
51+ template <typename V, typename W>
52+ void option_setter (std::string_view option_name, W& option,
53+ const std::initializer_list<V> values,
54+ std::initializer_list<std::string_view> value_names,
55+ std::string_view value_in)
56+ {
57+ // TODO: Can be done nicely with std::views::zip in C++23.
58+ for (auto i : std::views::iota (std::size_t {0 }, value_names.size ()))
59+ {
60+ if (value_in == *(value_names.begin () + i))
61+ {
62+ option = *(values.begin () + i);
63+ spdlog::info (" Set {} to {}" , option_name, value_in);
64+ return ;
65+ }
66+ }
67+ throw std::runtime_error (" Unsupported value" );
68+ }
69+
4970std::vector<int_t > col_indices (const auto & A)
5071{
5172 // Local number of non-zeros
@@ -135,15 +156,13 @@ create_supermatrix(const auto& A, auto& rowptr, auto& cols)
135156} // namespace
136157// ----------------------------------------------------------------------------
137158template <typename T>
138- SuperLUDistMatrix<T>::SuperLUDistMatrix(std::shared_ptr<const MatrixCSR<T>> A,
139- bool verbose)
159+ SuperLUDistMatrix<T>::SuperLUDistMatrix(std::shared_ptr<const MatrixCSR<T>> A)
140160 : _matA(std::move(A)),
141161 _cols (
142162 std::make_unique<SuperLUDistStructs::vec_int_t >(col_indices(*_matA))),
143163 _rowptr(
144164 std::make_unique<SuperLUDistStructs::vec_int_t >(row_indices(*_matA))),
145- _supermatrix(create_supermatrix<T>(*_matA, *_rowptr, *_cols)),
146- _verbose(verbose)
165+ _supermatrix(create_supermatrix<T>(*_matA, *_rowptr, *_cols))
147166{
148167}
149168// ----------------------------------------------------------------------------
@@ -170,18 +189,32 @@ struct dolfinx::la::SuperLUDistStructs::gridinfo_t : public ::gridinfo_t
170189{
171190};
172191// ----------------------------------------------------------------------------
192+ struct dolfinx ::la::SuperLUDistStructs::superlu_dist_options_t
193+ : public ::superlu_dist_options_t
194+ {
195+ };
196+ // ----------------------------------------------------------------------------
173197void GridInfoDeleter::operator ()(
174198 SuperLUDistStructs::gridinfo_t * gridinfo) const noexcept
175199{
176200 superlu_gridexit (gridinfo);
177201 delete gridinfo;
178- }
202+ };
179203
180204// ----------------------------------------------------------------------------
181205template <typename T>
182206SuperLUDistSolver<T>::SuperLUDistSolver(
183- std::shared_ptr<const SuperLUDistMatrix<T>> A, bool verbose )
207+ std::shared_ptr<const SuperLUDistMatrix<T>> A)
184208 : _superlu_matA(std::move(A)),
209+ _options (
210+ []
211+ {
212+ auto o = std::make_unique<
213+ SuperLUDistStructs::superlu_dist_options_t >();
214+ set_default_options_dist (o.get ());
215+ o->PrintStat = NO;
216+ return o;
217+ }()),
185218 _gridinfo(
186219 [comm = _superlu_matA->matA ().comm()]
187220 {
@@ -191,28 +224,105 @@ SuperLUDistSolver<T>::SuperLUDistSolver(
191224 new SuperLUDistStructs::gridinfo_t , GridInfoDeleter{});
192225 superlu_gridinit (comm, nprow, npcol, p.get ());
193226 return p;
194- }()),
195- _verbose(verbose)
227+ }())
228+ {
229+ }
230+
231+ template <typename T>
232+ void SuperLUDistSolver<T>::set_options(
233+ SuperLUDistStructs::superlu_dist_options_t options)
234+ {
235+ _options = std::make_unique<SuperLUDistStructs::superlu_dist_options_t >(
236+ std::move (options));
237+ }
238+
239+ template <typename T>
240+ void SuperLUDistSolver<T>::set_option(std::string name, std::string value)
196241{
242+ spdlog::info (" Attempting to set option {} to {}" , name, value);
243+ const std::map<std::string, std::reference_wrapper<yes_no_t >> map_bool
244+ = {{" Equil" , _options->Equil },
245+ {" DiagInv" , _options->DiagInv },
246+ {" SymmetricMode" , _options->SymmetricMode },
247+ {" PivotGrowth" , _options->PivotGrowth },
248+ {" ConditionNumber" , _options->ConditionNumber },
249+ {" ReplaceTinyPivot" , _options->ReplaceTinyPivot },
250+ {" SolveInitialized" , _options->SolveInitialized },
251+ {" RefineInitialized" , _options->RefineInitialized },
252+ {" PrintStat" , _options->PrintStat },
253+ {" lookahead_etree" , _options->lookahead_etree },
254+ {" SymPattern" , _options->SymPattern },
255+ {" Use_TensorCore" , _options->Use_TensorCore },
256+ {" Algo3d" , _options->Algo3d }};
257+
258+ // Search in map_bool first
259+ auto it = map_bool.find (name);
260+ if (it != map_bool.end ())
261+ {
262+ if (value == " YES" )
263+ {
264+ spdlog::info (" Set {} to YES" , name);
265+ it->second .get () = YES;
266+ }
267+ else if (value == " NO" )
268+ {
269+ spdlog::info (" Set {} to NO" , name);
270+ it->second .get () = NO;
271+ }
272+ else
273+ {
274+ throw std::runtime_error (" Boolean values must be string 'YES' or 'NO'" );
275+ }
276+ }
277+
278+ // Search some enum types
279+ if (name == " Fact" )
280+ {
281+ option_setter (
282+ name, _options->Fact ,
283+ {DOFACT, SamePattern, SamePattern_SameRowPerm, FACTORED},
284+ {" DOFACT" , " SamePattern" , " SamePattern_SameRowPerm" , " FACTORED" },
285+ value);
286+ }
287+ else if (name == " Trans" )
288+ {
289+ option_setter (name, _options->Trans , {NOTRANS, TRANS, CONJ},
290+ {" NOTRANS" , " TRANS" , " CONJ" }, value);
291+ }
292+ else if (name == " ColPerm" )
293+ {
294+ option_setter (name, _options->ColPerm ,
295+ {NATURAL, MMD_ATA, MMD_AT_PLUS_A, COLAMD, METIS_AT_PLUS_A,
296+ PARMETIS, METIS_ATA, ZOLTAN, MY_PERMC},
297+ {" NATURAL" , " MMD_ATA" , " MMD_AT_PLUS_A" , " COLAMD" ,
298+ " METIS_AT_PLUS_A" , " PARMETIS" , " METIS_ATA" , " ZOLTAN" ,
299+ " MY_PERMC" },
300+ value);
301+ }
302+ else if (name == " RowPerm" )
303+ {
304+ option_setter (name, _options->RowPerm ,
305+ {NOROWPERM, LargeDiag_MC64, LargeDiag_HWPM, MY_PERMR},
306+ {" NOROWPERM" , " LargeDiag_MC64" , " LargeDiag_HWPM" , " MY_PERMR" },
307+ value);
308+ }
309+ else
310+ {
311+ std::runtime_error (" Unsupported name" );
312+ }
197313}
198314// ----------------------------------------------------------------------------
199315template <typename T>
200316int SuperLUDistSolver<T>::solve(const la::Vector<T>& b, la::Vector<T>& u) const
201317{
318+ common::Timer tsolve (" SuperLU Solve" );
202319 int_t m = _superlu_matA->supermatrix ()->nrow ;
203320 int_t m_loc = ((NRformat_loc*)(_superlu_matA->supermatrix ()->Store ))->m_loc ;
204321
205322 // RHS
206323 int_t ldb = m_loc;
207324 int_t nrhs = 1 ;
208325
209- superlu_dist_options_t options;
210- set_default_options_dist (&options);
211- options.DiagInv = YES;
212- options.ReplaceTinyPivot = YES;
213- if (!_verbose)
214- options.PrintStat = NO;
215-
216326 int info = 0 ;
217327 SuperLUStat_t stat;
218328 PStatInit (&stat);
@@ -232,12 +342,12 @@ int SuperLUDistSolver<T>::solve(const la::Vector<T>& b, la::Vector<T>& u) const
232342 dSOLVEstruct_t SOLVEstruct;
233343
234344 spdlog::info (" Call SuperLU_DIST pdgssvx()" );
235- pdgssvx (&options , _superlu_matA->supermatrix (), &ScalePermstruct,
345+ pdgssvx (_options. get () , _superlu_matA->supermatrix (), &ScalePermstruct,
236346 u.array ().data (), ldb, nrhs, _gridinfo.get (), &LUstruct,
237347 &SOLVEstruct, berr.data (), &stat, &info);
238348
239349 spdlog::info (" Finalize solve" );
240- dSolveFinalize (&options , &SOLVEstruct);
350+ dSolveFinalize (_options. get () , &SOLVEstruct);
241351 dScalePermstructFree (&ScalePermstruct);
242352 dLUstructFree (&LUstruct);
243353 }
@@ -251,12 +361,12 @@ int SuperLUDistSolver<T>::solve(const la::Vector<T>& b, la::Vector<T>& u) const
251361 sSOLVEstruct_t SOLVEstruct;
252362
253363 spdlog::info (" Call SuperLU_DIST psgssvx()" );
254- psgssvx (&options , _superlu_matA->supermatrix (), &ScalePermstruct,
364+ psgssvx (_options. get () , _superlu_matA->supermatrix (), &ScalePermstruct,
255365 u.array ().data (), ldb, nrhs, _gridinfo.get (), &LUstruct,
256366 &SOLVEstruct, berr.data (), &stat, &info);
257367
258368 spdlog::info (" Finalize solve" );
259- sSolveFinalize (&options , &SOLVEstruct);
369+ sSolveFinalize (_options. get () , &SOLVEstruct);
260370 sScalePermstructFree (&ScalePermstruct);
261371 sLUstructFree (&LUstruct);
262372 }
@@ -270,13 +380,13 @@ int SuperLUDistSolver<T>::solve(const la::Vector<T>& b, la::Vector<T>& u) const
270380 zSOLVEstruct_t SOLVEstruct;
271381
272382 spdlog::info (" Call SuperLU_DIST pzgssvx()" );
273- pzgssvx (&options , _superlu_matA->supermatrix (), &ScalePermstruct,
383+ pzgssvx (_options. get () , _superlu_matA->supermatrix (), &ScalePermstruct,
274384 reinterpret_cast <doublecomplex*>(u.array ().data ()), ldb, nrhs,
275385 _gridinfo.get (), &LUstruct, &SOLVEstruct, berr.data (), &stat,
276386 &info);
277387
278388 spdlog::info (" Finalize solve" );
279- zSolveFinalize (&options , &SOLVEstruct);
389+ zSolveFinalize (_options. get () , &SOLVEstruct);
280390 zScalePermstructFree (&ScalePermstruct);
281391 zLUstructFree (&LUstruct);
282392 }
@@ -287,8 +397,7 @@ int SuperLUDistSolver<T>::solve(const la::Vector<T>& b, la::Vector<T>& u) const
287397 if (info != 0 )
288398 spdlog::info (" SuperLU_DIST p*gssvx() error: {}" , info);
289399
290- if (_verbose)
291- PStatPrint (&options, &stat, _gridinfo.get ());
400+ PStatPrint (_options.get (), &stat, _gridinfo.get ());
292401 PStatFree (&stat);
293402
294403 return info;
0 commit comments