// Copyright (c) 2010-2023, Lawrence Livermore National Security, LLC. Produced // at the Lawrence Livermore National Laboratory. All Rights reserved. See files // LICENSE and NOTICE for details. LLNL-CODE-806117. // // This file is part of the MFEM library. For more information and source code // availability visit https://mfem.org. // // MFEM is free software; you can redistribute it and/or modify it under the // terms of the BSD-3 license. We welcome feedback and contributions, see file // CONTRIBUTING.md for details. // Implementation of the MFEM wrapper for Nvidia's multigrid library, AmgX // // This work is partially based on: // // Pi-Yueh Chuang and Lorena A. Barba (2017). // AmgXWrapper: An interface between PETSc and the NVIDIA AmgX library. // J. Open Source Software, 2(16):280, doi:10.21105/joss.00280 // // See https://github.com/barbagroup/AmgXWrapper. #include "../config/config.hpp" #include "amgxsolver.hpp" #ifdef MFEM_USE_AMGX namespace mfem { int AmgXSolver::count = 0; AMGX_resources_handle AmgXSolver::rsrc = nullptr; AmgXSolver::AmgXSolver() : ConvergenceCheck(false) {}; AmgXSolver::AmgXSolver(const AMGX_MODE amgxMode_, const bool verbose) { amgxMode = amgxMode_; if (amgxMode == AmgXSolver::SOLVER) { ConvergenceCheck = true;} else { ConvergenceCheck = false;} DefaultParameters(amgxMode, verbose); InitSerial(); } #ifdef MFEM_USE_MPI AmgXSolver::AmgXSolver(const MPI_Comm &comm, const AMGX_MODE amgxMode_, const bool verbose) { std::string config; amgxMode = amgxMode_; if (amgxMode == AmgXSolver::SOLVER) { ConvergenceCheck = true;} else { ConvergenceCheck = false;} DefaultParameters(amgxMode, verbose); InitExclusiveGPU(comm); } AmgXSolver::AmgXSolver(const MPI_Comm &comm, const int nDevs, const AMGX_MODE amgxMode_, const bool verbose) { std::string config; amgxMode = amgxMode_; if (amgxMode == AmgXSolver::SOLVER) { ConvergenceCheck = true;} else { ConvergenceCheck = false;} DefaultParameters(amgxMode_, verbose); InitMPITeams(comm, nDevs); } #endif AmgXSolver::~AmgXSolver() { if (isInitialized) { Finalize(); } } void AmgXSolver::InitSerial() { count++; mpi_gpu_mode = "serial"; AMGX_SAFE_CALL(AMGX_initialize()); AMGX_SAFE_CALL(AMGX_initialize_plugins()); AMGX_SAFE_CALL(AMGX_install_signal_handler()); MFEM_VERIFY(configSrc != CONFIG_SRC::UNDEFINED, "AmgX configuration is not defined \n"); if (configSrc == CONFIG_SRC::EXTERNAL) { AMGX_SAFE_CALL(AMGX_config_create_from_file(&cfg, amgx_config.c_str())); } else { AMGX_SAFE_CALL(AMGX_config_create(&cfg, amgx_config.c_str())); } AMGX_SAFE_CALL(AMGX_resources_create_simple(&rsrc, cfg)); AMGX_SAFE_CALL(AMGX_solver_create(&solver, rsrc, precision_mode, cfg)); AMGX_SAFE_CALL(AMGX_matrix_create(&AmgXA, rsrc, precision_mode)); AMGX_SAFE_CALL(AMGX_vector_create(&AmgXP, rsrc, precision_mode)); AMGX_SAFE_CALL(AMGX_vector_create(&AmgXRHS, rsrc, precision_mode)); isInitialized = true; } #ifdef MFEM_USE_MPI void AmgXSolver::InitExclusiveGPU(const MPI_Comm &comm) { // If this instance has already been initialized, skip if (isInitialized) { mfem_error("This AmgXSolver instance has been initialized on this process."); } // Note that every MPI rank may talk to a GPU mpi_gpu_mode = "mpi-gpu-exclusive"; gpuProc = 0; // Increment number of AmgX instances count++; MPI_Comm_dup(comm, &gpuWorld); MPI_Comm_size(gpuWorld, &gpuWorldSize); MPI_Comm_rank(gpuWorld, &myGpuWorldRank); // Each rank will only see 1 device call it device 0 nDevs = 1, devID = 0; InitAmgX(); isInitialized = true; } // Initialize for MPI ranks > GPUs, all devices are visible to all of the MPI // ranks void AmgXSolver::InitMPITeams(const MPI_Comm &comm, const int nDevs) { // If this instance has already been initialized, skip if (isInitialized) { mfem_error("This AmgXSolver instance has been initialized on this process."); } mpi_gpu_mode = "mpi-teams"; // Increment number of AmgX instances count++; // Get the name of this node int len; char name[MPI_MAX_PROCESSOR_NAME]; MPI_Get_processor_name(name, &len); nodeName = name; int globalcommrank; MPI_Comm_rank(comm, &globalcommrank); // Initialize communicators and corresponding information InitMPIcomms(comm, nDevs); // Only processes in gpuWorld are required to initialize AmgX if (gpuProc == 0) { InitAmgX(); } isInitialized = true; } #endif void AmgXSolver::ReadParameters(const std::string config, const CONFIG_SRC source) { amgx_config = config; configSrc = source; } void AmgXSolver::SetConvergenceCheck(bool setConvergenceCheck_) { ConvergenceCheck = setConvergenceCheck_; } void AmgXSolver::DefaultParameters(const AMGX_MODE amgxMode_, const bool verbose) { amgxMode = amgxMode_; configSrc = INTERNAL; if (amgxMode == AMGX_MODE::PRECONDITIONER) { amgx_config = "{\n" " \"config_version\": 2, \n" " \"solver\": { \n" " \"solver\": \"AMG\", \n" " \"scope\": \"main\", \n" " \"smoother\": \"JACOBI_L1\", \n" " \"presweeps\": 1, \n" " \"interpolator\": \"D2\", \n" " \"max_row_sum\" : 0.9, \n" " \"strength_threshold\" : 0.25, \n" " \"postsweeps\": 1, \n" " \"max_iters\": 1, \n" " \"cycle\": \"V\""; if (verbose) { amgx_config = amgx_config + ",\n" " \"obtain_timings\": 1, \n" " \"print_grid_stats\": 1, \n" " \"monitor_residual\": 1, \n" " \"print_solve_stats\": 1 \n"; } else { amgx_config = amgx_config + "\n"; } amgx_config = amgx_config + " }\n" + "}\n"; // use a zero initial guess in Mult() iterative_mode = false; } else if (amgxMode == AMGX_MODE::SOLVER) { amgx_config = "{ \n" " \"config_version\": 2, \n" " \"solver\": { \n" " \"preconditioner\": { \n" " \"solver\": \"AMG\", \n" " \"smoother\": { \n" " \"scope\": \"jacobi\", \n" " \"solver\": \"JACOBI_L1\" \n" " }, \n" " \"presweeps\": 1, \n" " \"interpolator\": \"D2\", \n" " \"max_row_sum\" : 0.9, \n" " \"strength_threshold\" : 0.25, \n" " \"max_iters\": 1, \n" " \"scope\": \"amg\", \n" " \"max_levels\": 100, \n" " \"cycle\": \"V\", \n" " \"postsweeps\": 1 \n" " }, \n" " \"solver\": \"PCG\", \n" " \"max_iters\": 150, \n" " \"convergence\": \"RELATIVE_INI_CORE\", \n" " \"scope\": \"main\", \n" " \"tolerance\": 1e-12, \n" " \"monitor_residual\": 1, \n" " \"norm\": \"L2\" "; if (verbose) { amgx_config = amgx_config + ", \n" " \"obtain_timings\": 1, \n" " \"print_grid_stats\": 1, \n" " \"print_solve_stats\": 1 \n"; } else { amgx_config = amgx_config + "\n"; } amgx_config = amgx_config + " } \n" + "} \n"; // use the user-specified vector as an initial guess in Mult() iterative_mode = true; } else { mfem_error("AmgX mode not supported \n"); } } // Sets up AmgX library for MPI builds #ifdef MFEM_USE_MPI void AmgXSolver::InitAmgX() { // Set up once if (count == 1) { AMGX_SAFE_CALL(AMGX_initialize()); AMGX_SAFE_CALL(AMGX_initialize_plugins()); AMGX_SAFE_CALL(AMGX_install_signal_handler()); AMGX_SAFE_CALL(AMGX_register_print_callback( [](const char *msg, int length)->void { int irank; MPI_Comm_rank(MPI_COMM_WORLD, &irank); if (irank == 0) { mfem::out< localSize) // there are more devices than processes { MFEM_WARNING("CUDA devices on the node " << nodeName.c_str() << " are more than the MPI processes launched. Only "<< nDevs << " devices will be used.\n"); devID = myLocalRank; gpuProc = 0; } else // in case there are more ranks than devices { int nBasic = localSize / nDevs, nRemain = localSize % nDevs; if (myLocalRank < (nBasic+1)*nRemain) { devID = myLocalRank / (nBasic + 1); if (myLocalRank % (nBasic + 1) == 0) { gpuProc = 0; } } else { devID = (myLocalRank - (nBasic+1)*nRemain) / nBasic + nRemain; if ((myLocalRank - (nBasic+1)*nRemain) % nBasic == 0) { gpuProc = 0; } } } } void AmgXSolver::GatherArray(const Array &inArr, Array &outArr, const int mpiTeamSz, const MPI_Comm &mpiTeamComm) const { // Calculate number of elements to be collected from each process Array Apart(mpiTeamSz); int locAsz = inArr.Size(); MPI_Gather(&locAsz, 1, MPI_INT, Apart.HostWrite(),1, MPI_INT,0,mpiTeamComm); MPI_Barrier(mpiTeamComm); // Determine stride for process (to be used by root) Array Adisp(mpiTeamSz); int myid; MPI_Comm_rank(mpiTeamComm, &myid); if (myid == 0) { Adisp[0] = 0; for (int i=1; i Apart(mpiTeamSz); int locAsz = inArr.Size(); MPI_Gather(&locAsz, 1, MPI_INT, Apart.HostWrite(),1, MPI_INT,0,mpiTeamComm); MPI_Barrier(mpiTeamComm); // Determine stride for process (to be used by root) Array Adisp(mpiTeamSz); int myid; MPI_Comm_rank(mpiTeamComm, &myid); if (myid == 0) { Adisp[0] = 0; for (int i=1; i &inArr, Array &outArr, const int mpiTeamSz, const MPI_Comm &mpiTeamComm) const { // Calculate number of elements to be collected from each process Array Apart(mpiTeamSz); int locAsz = inArr.Size(); MPI_Gather(&locAsz, 1, MPI_INT, Apart.GetData(),1, MPI_INT,0,mpiTeamComm); MPI_Barrier(mpiTeamComm); // Determine stride for process (to be used by root) Array Adisp(mpiTeamSz); int myid; MPI_Comm_rank(mpiTeamComm, &myid); if (myid == 0) { Adisp[0] = 0; for (int i=1; i &inArr, Array &outArr, const int mpiTeamSz, const MPI_Comm &mpiTeamComm) const { // Calculate number of elements to be collected from each process Array Apart(mpiTeamSz); int locAsz = inArr.Size(); MPI_Gather(&locAsz, 1, MPI_INT, Apart.GetData(),1, MPI_INT,0,mpiTeamComm); MPI_Barrier(mpiTeamComm); // Determine stride for process Array Adisp(mpiTeamSz); int myid; MPI_Comm_rank(mpiTeamComm, &myid); if (myid == 0) { Adisp[0] = 0; for (int i=1; i &Apart, Array &Adisp) const { // Calculate number of elements to be collected from each process int locAsz = inArr.Size(); MPI_Allgather(&locAsz, 1, MPI_INT, Apart.HostWrite(),1, MPI_INT, mpiTeamComm); MPI_Barrier(mpiTeamComm); // Determine stride for process Adisp[0] = 0; for (int i=1; i &Apart, Array &Adisp) const { MPI_Scatterv(inArr.HostRead(),Apart.HostRead(),Adisp.HostRead(), MPI_DOUBLE,outArr.HostWrite(),outArr.Size(), MPI_DOUBLE, 0, mpiTeamComm); } #endif void AmgXSolver::SetMatrix(const SparseMatrix &in_A, const bool update_mat) { if (update_mat == false) { AMGX_SAFE_CALL(AMGX_matrix_upload_all(AmgXA, in_A.Height(), in_A.NumNonZeroElems(), 1, 1, in_A.ReadI(), in_A.ReadJ(), in_A.ReadData(), NULL)); AMGX_SAFE_CALL(AMGX_solver_setup(solver, AmgXA)); AMGX_SAFE_CALL(AMGX_vector_bind(AmgXP, AmgXA)); AMGX_SAFE_CALL(AMGX_vector_bind(AmgXRHS, AmgXA)); } else { AMGX_SAFE_CALL(AMGX_matrix_replace_coefficients(AmgXA, in_A.Height(), in_A.NumNonZeroElems(), in_A.ReadData(), NULL)); } } #ifdef MFEM_USE_MPI void AmgXSolver::SetMatrix(const HypreParMatrix &A, const bool update_mat) { // Require hypre >= 2.16. #if MFEM_HYPRE_VERSION < 21600 mfem_error("Hypre version 2.16+ is required when using AmgX \n"); #endif // Ensure HypreParMatrix is on the host A.HostRead(); hypre_ParCSRMatrix * A_ptr = (hypre_ParCSRMatrix *)const_cast(A); hypre_CSRMatrix *A_csr = hypre_MergeDiagAndOffd(A_ptr); A.HypreRead(); Array loc_A(A_csr->data, (int)A_csr->num_nonzeros); const Array loc_I(A_csr->i, (int)A_csr->num_rows+1); // Column index must be int64_t so we must promote here Array loc_J((int)A_csr->num_nonzeros); for (int i=0; inum_nonzeros; ++i) { loc_J[i] = A_csr->big_j[i]; } // Assumes one GPU per MPI rank if (mpi_gpu_mode=="mpi-gpu-exclusive") { SetMatrixMPIGPUExclusive(A, loc_A, loc_I, loc_J, update_mat); // Free A_csr data from hypre_MergeDiagAndOffd method hypre_CSRMatrixDestroy(A_csr); return; } // Assumes teams of MPI ranks are sharing a GPU if (mpi_gpu_mode == "mpi-teams") { SetMatrixMPITeams(A, loc_A, loc_I, loc_J, update_mat); // Free A_csr data from hypre_MergeDiagAndOffd method hypre_CSRMatrixDestroy(A_csr); return; } mfem_error("Unsupported MPI_GPU combination \n"); } void AmgXSolver::SetMatrixMPIGPUExclusive(const HypreParMatrix &A, const Array &loc_A, const Array &loc_I, const Array &loc_J, const bool update_mat) { // Create a vector of offsets describing matrix row partitions Array rowPart(gpuWorldSize+1); rowPart = 0.0; int64_t myStart = A.GetRowStarts()[0]; MPI_Allgather(&myStart, 1, MPI_INT64_T, rowPart.GetData(),1, MPI_INT64_T ,gpuWorld); MPI_Barrier(gpuWorld); rowPart[gpuWorldSize] = A.M(); const int nGlobalRows = A.M(); const int local_rows = loc_I.Size()-1; const int num_nnz = loc_I[local_rows]; if (update_mat == false) { AMGX_distribution_handle dist; AMGX_SAFE_CALL(AMGX_distribution_create(&dist, cfg)); AMGX_SAFE_CALL(AMGX_distribution_set_partition_data(dist, AMGX_DIST_PARTITION_OFFSETS, rowPart.GetData())); AMGX_SAFE_CALL(AMGX_matrix_upload_distributed(AmgXA, nGlobalRows, local_rows, num_nnz, 1, 1, loc_I.Read(), loc_J.Read(), loc_A.Read(), NULL, dist)); AMGX_SAFE_CALL(AMGX_distribution_destroy(dist)); MPI_Barrier(gpuWorld); AMGX_SAFE_CALL(AMGX_solver_setup(solver, AmgXA)); AMGX_SAFE_CALL(AMGX_vector_bind(AmgXP, AmgXA)); AMGX_SAFE_CALL(AMGX_vector_bind(AmgXRHS, AmgXA)); } else { AMGX_SAFE_CALL(AMGX_matrix_replace_coefficients(AmgXA, nGlobalRows, num_nnz, loc_A, NULL)); } } void AmgXSolver::SetMatrixMPITeams(const HypreParMatrix &A, const Array &loc_A, const Array &loc_I, const Array &loc_J, const bool update_mat) { // The following arrays hold the consolidated diagonal + off-diagonal matrix // data Array all_I; Array all_J; Array all_A; // Determine array sizes int J_allsz(0), all_NNZ(0), nDevRows(0); const int loc_row_len = std::abs(A.RowPart()[1] - A.RowPart()[0]); // end of row partition const int loc_Jz_sz = loc_J.Size(); const int loc_A_sz = loc_A.Size(); MPI_Reduce(&loc_row_len, &nDevRows, 1, MPI_INT, MPI_SUM, 0, devWorld); MPI_Reduce(&loc_Jz_sz, &J_allsz, 1, MPI_INT, MPI_SUM, 0, devWorld); MPI_Reduce(&loc_A_sz, &all_NNZ, 1, MPI_INT, MPI_SUM, 0, devWorld); MPI_Barrier(devWorld); if (myDevWorldRank == 0) { all_I.SetSize(nDevRows+devWorldSize); all_J.SetSize(J_allsz); all_J = 0.0; all_A.SetSize(all_NNZ); } GatherArray(loc_I, all_I, devWorldSize, devWorld); GatherArray(loc_J, all_J, devWorldSize, devWorld); GatherArray(loc_A, all_A, devWorldSize, devWorld); MPI_Barrier(devWorld); int local_nnz(0); int64_t local_rows(0); if (myDevWorldRank == 0) { // A fix up step is needed for the array holding row data to remove extra // zeros when consolidating team data. Array z_ind(devWorldSize+1); int iter = 1; while (iter < devWorldSize-1) { // Determine the indices of zeros in global all_I array int counter = 0; z_ind[counter] = counter; counter++; for (int idx=1; idx rowPart; if (gpuProc == 0) { rowPart.SetSize(gpuWorldSize+1); rowPart=0; MPI_Allgather(&local_rows, 1, MPI_INT64_T, &rowPart.GetData()[1], 1, MPI_INT64_T, gpuWorld); MPI_Barrier(gpuWorld); // Fixup step for (int i=1; i(&op)) { SetMatrix(*Aptr); } #ifdef MFEM_USE_MPI else if (const HypreParMatrix* Aptr = dynamic_cast(&op)) { SetMatrix(*Aptr); } #endif else { mfem_error("Unsupported Operator Type \n"); } } void AmgXSolver::UpdateOperator(const Operator& op) { if (const SparseMatrix* Aptr = dynamic_cast(&op)) { SetMatrix(*Aptr, true); } #ifdef MFEM_USE_MPI else if (const HypreParMatrix* Aptr = dynamic_cast(&op)) { SetMatrix(*Aptr, true); } #endif else { mfem_error("Unsupported Operator Type \n"); } } void AmgXSolver::Mult(const Vector& B, Vector& X) const { // Set initial guess to zero X.UseDevice(true); if (!iterative_mode) { X = 0.0; } // Mult for serial, and mpi-exclusive modes if (mpi_gpu_mode != "mpi-teams") { AMGX_SAFE_CALL(AMGX_vector_upload(AmgXP, X.Size(), 1, X.ReadWrite())); AMGX_SAFE_CALL(AMGX_vector_upload(AmgXRHS, B.Size(), 1, B.Read())); if (mpi_gpu_mode != "serial") { #ifdef MFEM_USE_MPI MPI_Barrier(gpuWorld); #endif } AMGX_SAFE_CALL(AMGX_solver_solve(solver,AmgXRHS, AmgXP)); AMGX_SOLVE_STATUS status; AMGX_SAFE_CALL(AMGX_solver_get_status(solver, &status)); if (status != AMGX_SOLVE_SUCCESS && ConvergenceCheck) { if (status == AMGX_SOLVE_DIVERGED) { mfem_error("AmgX solver diverged \n"); } else { mfem_error("AmgX solver failed to solve system \n"); } } AMGX_SAFE_CALL(AMGX_vector_download(AmgXP, X.Write())); return; } #ifdef MFEM_USE_MPI Vector all_X(mat_local_rows); Vector all_B(mat_local_rows); Array Apart_X(devWorldSize); Array Adisp_X(devWorldSize); Array Apart_B(devWorldSize); Array Adisp_B(devWorldSize); GatherArray(X, all_X, devWorldSize, devWorld, Apart_X, Adisp_X); GatherArray(B, all_B, devWorldSize, devWorld, Apart_B, Adisp_B); MPI_Barrier(devWorld); if (gpuWorld != MPI_COMM_NULL) { AMGX_SAFE_CALL(AMGX_vector_upload(AmgXP, all_X.Size(), 1, all_X.ReadWrite())); AMGX_SAFE_CALL(AMGX_vector_upload(AmgXRHS, all_B.Size(), 1, all_B.ReadWrite())); MPI_Barrier(gpuWorld); AMGX_SAFE_CALL(AMGX_solver_solve(solver,AmgXRHS, AmgXP)); AMGX_SOLVE_STATUS status; AMGX_SAFE_CALL(AMGX_solver_get_status(solver, &status)); if (status != AMGX_SOLVE_SUCCESS && amgxMode == SOLVER) { if (status == AMGX_SOLVE_DIVERGED) { mfem_error("AmgX solver diverged \n"); } else { mfem_error("AmgX solver failed to solve system \n"); } } AMGX_SAFE_CALL(AMGX_vector_download(AmgXP, all_X.Write())); } ScatterArray(all_X, X, devWorldSize, devWorld, Apart_X, Adisp_X); #endif } int AmgXSolver::GetNumIterations() { int getIters; AMGX_SAFE_CALL(AMGX_solver_get_iterations_number(solver, &getIters)); return getIters; } void AmgXSolver::Finalize() { // Check instance is initialized if (! isInitialized || count < 1) { mfem_error("Error in AmgXSolver::Finalize(). \n" "This AmgXWrapper has not been initialized. \n" "Please initialize it before finalization.\n"); } // Only processes using GPU are required to destroy AmgX content #ifdef MFEM_USE_MPI if (gpuProc == 0 || mpi_gpu_mode == "serial") #endif { // Destroy solver instance AMGX_SAFE_CALL(AMGX_solver_destroy(solver)); // Destroy matrix instance AMGX_SAFE_CALL(AMGX_matrix_destroy(AmgXA)); // Destroy RHS and unknown vectors AMGX_SAFE_CALL(AMGX_vector_destroy(AmgXP)); AMGX_SAFE_CALL(AMGX_vector_destroy(AmgXRHS)); // Only the last instance need to destroy resource and finalizing AmgX if (count == 1) { AMGX_SAFE_CALL(AMGX_resources_destroy(rsrc)); AMGX_SAFE_CALL(AMGX_config_destroy(cfg)); AMGX_SAFE_CALL(AMGX_finalize_plugins()); AMGX_SAFE_CALL(AMGX_finalize()); } else { AMGX_SAFE_CALL(AMGX_config_destroy(cfg)); } #ifdef MFEM_USE_MPI // destroy gpuWorld if (mpi_gpu_mode != "serial") { MPI_Comm_free(&gpuWorld); } #endif } // reset necessary variables in case users want to reuse the variable of // this instance for a new instance #ifdef MFEM_USE_MPI gpuProc = MPI_UNDEFINED; if (globalCpuWorld != MPI_COMM_NULL) { MPI_Comm_free(&globalCpuWorld); MPI_Comm_free(&localCpuWorld); MPI_Comm_free(&devWorld); } #endif // decrease the number of instances count -= 1; // change status isInitialized = false; } } // mfem namespace #endif