20#ifndef OPM_AMGX_INTERFACE_HPP
21#define OPM_AMGX_INTERFACE_HPP
23#include <opm/common/ErrorMacros.hpp>
24#include <opm/simulators/linalg/gpuistl/GpuSparseMatrixWrapper.hpp>
25#include <opm/simulators/linalg/gpuistl/GpuVector.hpp>
26#include <opm/simulators/linalg/gpuistl/detail/gpu_safe_call.hpp>
27#include <opm/simulators/linalg/gpuistl/detail/gpu_type_detection.hpp>
30#include <cuda_runtime.h>
43class AmgxError :
public std::runtime_error
46 explicit AmgxError(
const std::string& msg)
47 : std::runtime_error(msg)
64 AMGX_RC err,
const std::string& expression,
const std::string& file,
const std::string& function,
int line)
66 char amgx_err_msg[4096];
67 AMGX_get_error_string(err, amgx_err_msg,
sizeof(amgx_err_msg));
70 "AMGX error in expression: {}\nError code: {}\nError message: {}\nLocation: {}:{} in function {}",
72 static_cast<int>(err),
92amgxSafeCall(AMGX_RC rc,
const std::string& expression,
const std::string& file,
const std::string& function,
int line)
94 if (rc != AMGX_RC_OK) {
107#define OPM_AMGX_SAFE_CALL(expr) ::Opm::gpuistl::amgxSafeCall((expr), #expr, __FILE__, __func__, __LINE__)
132 OPM_AMGX_SAFE_CALL(AMGX_initialize());
143 OPM_AMGX_SAFE_CALL(AMGX_finalize());
153 static AMGX_config_handle
createConfig(
const std::string& config_string)
155 AMGX_config_handle config;
156 OPM_AMGX_SAFE_CALL(AMGX_config_create(&config, config_string.c_str()));
169 AMGX_resources_handle resources;
170 OPM_AMGX_SAFE_CALL(AMGX_resources_create_simple(&resources, config));
183 static AMGX_solver_handle
createSolver(AMGX_resources_handle resources, AMGX_Mode mode, AMGX_config_handle config)
185 AMGX_solver_handle solver;
186 OPM_AMGX_SAFE_CALL(AMGX_solver_create(&solver, resources, mode, config));
198 static AMGX_matrix_handle
createMatrix(AMGX_resources_handle resources, AMGX_Mode mode)
200 AMGX_matrix_handle matrix;
201 OPM_AMGX_SAFE_CALL(AMGX_matrix_create(&matrix, resources, mode));
213 static AMGX_vector_handle
createVector(AMGX_resources_handle resources, AMGX_Mode mode)
215 AMGX_vector_handle vector;
216 OPM_AMGX_SAFE_CALL(AMGX_vector_create(&vector, resources, mode));
229 OPM_AMGX_SAFE_CALL(AMGX_config_destroy(config));
247 OPM_AMGX_SAFE_CALL(AMGX_resources_destroy(resources));
260 OPM_AMGX_SAFE_CALL(AMGX_solver_destroy(solver));
271 template <
typename MatrixType>
272 static void destroyMatrix(AMGX_matrix_handle amgx_matrix,
const MatrixType& matrix)
276 using T =
typename MatrixType::field_type;
277 const T* values = &(matrix[0][0][0][0]);
278 OPM_AMGX_SAFE_CALL(AMGX_unpin_memory(
const_cast<T*
>(values)));
282 OPM_AMGX_SAFE_CALL(AMGX_matrix_destroy(amgx_matrix));
295 OPM_AMGX_SAFE_CALL(AMGX_vector_destroy(vector));
309 template <
typename T>
314 OPM_AMGX_SAFE_CALL(AMGX_vector_get_size(amgx_vec, &n, &block_dim));
316 if (n > 0 &&
static_cast<size_t>(n * block_dim) != gpu_vec.dim()) {
317 throw AmgxError(fmt::format(
"Vector size mismatch in updateAmgxFromGpuVector: "
318 "AMGX vector size {} vs. GpuVector size {}",
324 const T* device_ptr = gpu_vec.data();
327 OPM_AMGX_SAFE_CALL(AMGX_vector_upload(amgx_vec, n, block_dim,
const_cast<T*
>(device_ptr)));
339 template <
typename T>
344 OPM_AMGX_SAFE_CALL(AMGX_vector_get_size(amgx_vec, &n, &block_dim));
346 if (
static_cast<size_t>(n * block_dim) != gpu_vec.dim()) {
347 throw AmgxError(fmt::format(
"Vector size mismatch in updateGpuVectorFromAmgx: "
348 "AMGX vector size {} vs. GpuVector size {}",
354 T* dst_device_ptr = gpu_vec.data();
357 OPM_AMGX_SAFE_CALL(AMGX_vector_download(amgx_vec, dst_device_ptr));
369 template <
typename VectorType>
378 const int N = vec.size();
379 const int block_size = 1;
382 OPM_AMGX_SAFE_CALL(AMGX_vector_upload(amgx_vec, N, block_size, &vec[0][0]));
395 template <
typename VectorType>
403 OPM_AMGX_SAFE_CALL(AMGX_vector_download(amgx_vec, &vec[0][0]));
416 template <
typename T>
418 AMGX_matrix_handle amgxMatrix)
427 const int* row_ptrs = gpuSparseMatrix.
getRowIndices().data();
432 AMGX_matrix_upload_all(amgxMatrix, n, nnz, block_size, block_size, row_ptrs, col_indices, values,
nullptr));
445 template <
typename T>
447 AMGX_matrix_handle amgxMatrix)
457 OPM_AMGX_SAFE_CALL(AMGX_matrix_replace_coefficients(amgxMatrix, n, nnz, values,
nullptr));
473 template <
typename T>
477 int n, nnz, block_sizex, block_sizey;
478 OPM_AMGX_SAFE_CALL(AMGX_matrix_get_size(amgxMatrix, &n, &block_sizex, &block_sizey));
479 OPM_AMGX_SAFE_CALL(AMGX_matrix_get_nnz(amgxMatrix, &nnz));
483 int* temp_col_indices;
484 OPM_GPU_SAFE_CALL(cudaMalloc(&temp_row_ptrs, (n + 1) *
sizeof(
int)));
485 OPM_GPU_SAFE_CALL(cudaMalloc(&temp_col_indices, nnz *
sizeof(
int)));
491 void* diag_data_ptr =
nullptr;
494 OPM_AMGX_SAFE_CALL(AMGX_matrix_download_all(amgxMatrix,
501 OPM_GPU_SAFE_CALL(cudaFree(temp_row_ptrs));
502 OPM_GPU_SAFE_CALL(cudaFree(temp_col_indices));
514 template <
typename MatrixType>
524 const int block_size = 1;
527 std::vector<int> row_ptrs(N + 1);
528 std::vector<int> col_indices(nnz);
532 for (
auto row = matrix.begin(); row != matrix.end(); ++row) {
533 for (
auto col = row->begin(); col != row->end(); ++col) {
534 col_indices[pos++] = col.index();
536 row_ptrs[row.index() + 1] = pos;
540 using T =
typename MatrixType::field_type;
541 const T* values = &(matrix[0][0][0][0]);
549 OPM_AMGX_SAFE_CALL(AMGX_pin_memory(
const_cast<T*
>(values),
sizeof(T) * nnz * block_size * block_size));
552 OPM_AMGX_SAFE_CALL(AMGX_matrix_upload_all(amgx_matrix,
559 const_cast<T*
>(values),
576 OPM_AMGX_SAFE_CALL(AMGX_vector_set_zero(amgx_vector, N, block_size));
589 template <
typename MatrixType>
597 using T =
typename MatrixType::field_type;
598 const T* values = &(matrix[0][0][0][0]);
599 OPM_AMGX_SAFE_CALL(AMGX_matrix_replace_coefficients(
600 amgx_matrix, matrix.N(), matrix.nonzeroes(),
const_cast<T*
>(values),
nullptr));
612 template <
typename MatrixFieldType,
typename VectorFieldType>
615 if constexpr (std::is_same_v<MatrixFieldType, double> && std::is_same_v<VectorFieldType, double>) {
616 return AMGX_mode_dDDI;
617 }
else if constexpr (std::is_same_v<MatrixFieldType, float> && std::is_same_v<VectorFieldType, double>) {
618 return AMGX_mode_dDFI;
619 }
else if constexpr (std::is_same_v<MatrixFieldType, float> && std::is_same_v<VectorFieldType, float>) {
620 return AMGX_mode_dFFI;
622 OPM_THROW(std::runtime_error,
623 "Unsupported combination of matrix and vector types for AMGX: "
624 + std::string(
typeid(MatrixFieldType).name()) +
" and "
625 + std::string(
typeid(VectorFieldType).name()));
Exception class for AMGX errors.
Definition AmgxInterface.hpp:44
Unified interface for AMGX operations with both CPU and GPU data structures.
Definition AmgxInterface.hpp:122
static void updateAmgxMatrixCoefficientsFromGpuSparseMatrix(const GpuSparseMatrixWrapper< T > &gpuSparseMatrix, AMGX_matrix_handle amgxMatrix)
Update only the coefficient values of an AMGX matrix from a GpuSparseMatrix.
Definition AmgxInterface.hpp:446
static AMGX_Mode determineAmgxMode()
Determine the appropriate AMGX mode based on matrix and vector field types.
Definition AmgxInterface.hpp:613
static void updateGpuVectorFromAmgx(AMGX_vector_handle amgx_vec, GpuVector< T > &gpu_vec)
Update a GpuVector from an AMGX vector (device-to-device transfer).
Definition AmgxInterface.hpp:340
static void destroyResources(AMGX_resources_handle resources)
Destroy an AMGX resources handle.
Definition AmgxInterface.hpp:244
static AMGX_resources_handle createResources(AMGX_config_handle config)
Create AMGX resources from a config.
Definition AmgxInterface.hpp:167
static void initialize()
Initialize the AMGX library.
Definition AmgxInterface.hpp:130
static void destroyVector(AMGX_vector_handle vector)
Destroy an AMGX vector handle.
Definition AmgxInterface.hpp:292
static AMGX_solver_handle createSolver(AMGX_resources_handle resources, AMGX_Mode mode, AMGX_config_handle config)
Create an AMGX solver.
Definition AmgxInterface.hpp:183
static void updateMatrixValues(const MatrixType &matrix, AMGX_matrix_handle amgx_matrix)
Update matrix values in AMGX.
Definition AmgxInterface.hpp:590
static void updateAmgxFromGpuVector(const GpuVector< T > &gpu_vec, AMGX_vector_handle amgx_vec)
Update an AMGX vector from a GpuVector (device-to-device transfer).
Definition AmgxInterface.hpp:310
static void initializeVector(int N, int block_size, AMGX_vector_handle amgx_vector)
Initialize an AMGX vector with zeros.
Definition AmgxInterface.hpp:574
static void destroyConfig(AMGX_config_handle config)
Destroy an AMGX config handle.
Definition AmgxInterface.hpp:226
static AMGX_config_handle createConfig(const std::string &config_string)
Create an AMGX config handle from a configuration string.
Definition AmgxInterface.hpp:153
static void destroySolver(AMGX_solver_handle solver)
Destroy an AMGX solver handle.
Definition AmgxInterface.hpp:257
static AMGX_matrix_handle createMatrix(AMGX_resources_handle resources, AMGX_Mode mode)
Create an AMGX matrix.
Definition AmgxInterface.hpp:198
static void initializeMatrix(const MatrixType &matrix, AMGX_matrix_handle amgx_matrix)
Initialize an AMGX matrix from any matrix type (CPU or GPU).
Definition AmgxInterface.hpp:515
static void updateAmgxMatrixFromGpuSparseMatrix(const GpuSparseMatrixWrapper< T > &gpuSparseMatrix, AMGX_matrix_handle amgxMatrix)
Update an AMGX matrix from a GpuSparseMatrixWrapper (device-to-device transfer).
Definition AmgxInterface.hpp:417
static AMGX_vector_handle createVector(AMGX_resources_handle resources, AMGX_Mode mode)
Create an AMGX vector.
Definition AmgxInterface.hpp:213
static void finalize()
Finalize the AMGX library.
Definition AmgxInterface.hpp:141
static void destroyMatrix(AMGX_matrix_handle amgx_matrix, const MatrixType &matrix)
Destroy an AMGX matrix handle.
Definition AmgxInterface.hpp:272
static void updateGpuSparseMatrixFromAmgxMatrix(AMGX_matrix_handle amgxMatrix, GpuSparseMatrixWrapper< T > &gpuSparseMatrix)
Update a GpuSparseMatrixWrapper from an AMGX matrix (device-to-device transfer).
Definition AmgxInterface.hpp:474
static void transferVectorToAmgx(const VectorType &vec, AMGX_vector_handle amgx_vec)
Transfer vector to AMGX from any vector type (CPU or GPU).
Definition AmgxInterface.hpp:370
static void transferVectorFromAmgx(AMGX_vector_handle amgx_vec, VectorType &vec)
Transfer vector from AMGX to any vector type (CPU or GPU).
Definition AmgxInterface.hpp:396
The GpuSparseMatrixWrapper checks the CUDA/HIP version and dispatches to an implementation that uses either the old or the new ...
Definition GpuSparseMatrixWrapper.hpp:62
GpuVector< int > & getRowIndices()
getRowIndices returns the row pointer (row offset) array used to represent the BSR structure.
Definition GpuSparseMatrixWrapper.hpp:271
GpuVector< T > & getNonZeroValues()
getNonZeroValues returns the GPU vector containing the non-zero values (ordered by block)
Definition GpuSparseMatrixWrapper.hpp:251
std::size_t N() const
N returns the number of rows (which is equal to the number of columns).
Definition GpuSparseMatrixWrapper.hpp:232
std::size_t nonzeroes() const
nonzeroes behaves like the Dune::BCRSMatrix::nonzeroes() function and returns the number of nonzero blocks.
Definition GpuSparseMatrixWrapper.hpp:241
GpuVector< int > & getColumnIndices()
getColumnIndices returns the column indices used to represent the BSR structure.
Definition GpuSparseMatrixWrapper.hpp:291
std::size_t blockSize() const
blockSize returns the size of the blocks.
Definition GpuSparseMatrixWrapper.hpp:320
Definition gpu_type_detection.hpp:30
int to_int(std::size_t s)
to_int converts a std::size_t (64-bit unsigned on most relevant platforms) to a 32-bit signed int.
Definition safe_conversion.hpp:56
A small, fixed-dimension MiniVector class backed by std::array that can be used in both host and CUDA device code.
Definition AmgxInterface.hpp:38
std::string getAmgxErrorMessage(AMGX_RC err, const std::string &expression, const std::string &file, const std::string &function, int line)
Get a descriptive error message for an AMGX error code.
Definition AmgxInterface.hpp:63
void amgxSafeCall(AMGX_RC rc, const std::string &expression, const std::string &file, const std::string &function, int line)
Safe call wrapper for AMGX functions.
Definition AmgxInterface.hpp:92
Type trait to detect if a type is a GPU type.
Definition gpu_type_detection.hpp:40