/* Copyright 2021 The JAX Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "jaxlib/cpu/lapack_kernels.h"

#include <complex>

// From a Python binary, JAX obtains its LAPACK/BLAS kernels from Scipy, but
// a C++ user should link against LAPACK directly. This is needed when using
// JAX-generated HLO from C++.

// Declarations of the Fortran LAPACK/BLAS entry points, using the function
// types exported by each JAX kernel so that signatures stay in sync with the
// kernels that call them.
//
// LAPACK names encode the element type in their first letter:
//   s -> float, d -> double,
//   c -> std::complex<float>, z -> std::complex<double>.
// Some complex routines use Hermitian/unitary names in LAPACK (ungqr, heevd,
// hetrd) but are bound to the same JAX kernel family as their real
// counterparts (Orgqr, ComplexHeevd, Sytrd). The trailing underscore is the
// usual Fortran symbol-mangling convention.
extern "C" {

jax::Trsm<float>::FnType strsm_;
jax::Trsm<double>::FnType dtrsm_;
jax::Trsm<std::complex<float>>::FnType ctrsm_;
jax::Trsm<std::complex<double>>::FnType ztrsm_;

jax::Getrf<float>::FnType sgetrf_;
jax::Getrf<double>::FnType dgetrf_;
jax::Getrf<std::complex<float>>::FnType cgetrf_;
jax::Getrf<std::complex<double>>::FnType zgetrf_;

jax::Geqrf<float>::FnType sgeqrf_;
jax::Geqrf<double>::FnType dgeqrf_;
jax::Geqrf<std::complex<float>>::FnType cgeqrf_;
jax::Geqrf<std::complex<double>>::FnType zgeqrf_;

jax::Orgqr<float>::FnType sorgqr_;
jax::Orgqr<double>::FnType dorgqr_;
jax::Orgqr<std::complex<float>>::FnType cungqr_;
jax::Orgqr<std::complex<double>>::FnType zungqr_;

jax::Potrf<float>::FnType spotrf_;
jax::Potrf<double>::FnType dpotrf_;
jax::Potrf<std::complex<float>>::FnType cpotrf_;
jax::Potrf<std::complex<double>>::FnType zpotrf_;

jax::RealGesdd<float>::FnType sgesdd_;
jax::RealGesdd<double>::FnType dgesdd_;
jax::ComplexGesdd<std::complex<float>>::FnType cgesdd_;
jax::ComplexGesdd<std::complex<double>>::FnType zgesdd_;

jax::RealSyevd<float>::FnType ssyevd_;
jax::RealSyevd<double>::FnType dsyevd_;
jax::ComplexHeevd<std::complex<float>>::FnType cheevd_;
jax::ComplexHeevd<std::complex<double>>::FnType zheevd_;

jax::RealGeev<float>::FnType sgeev_;
jax::RealGeev<double>::FnType dgeev_;
jax::ComplexGeev<std::complex<float>>::FnType cgeev_;
jax::ComplexGeev<std::complex<double>>::FnType zgeev_;

jax::RealGees<float>::FnType sgees_;
jax::RealGees<double>::FnType dgees_;
jax::ComplexGees<std::complex<float>>::FnType cgees_;
jax::ComplexGees<std::complex<double>>::FnType zgees_;

jax::Gehrd<float>::FnType sgehrd_;
jax::Gehrd<double>::FnType dgehrd_;
jax::Gehrd<std::complex<float>>::FnType cgehrd_;
jax::Gehrd<std::complex<double>>::FnType zgehrd_;

jax::Sytrd<float>::FnType ssytrd_;
jax::Sytrd<double>::FnType dsytrd_;
jax::Sytrd<std::complex<float>>::FnType chetrd_;
jax::Sytrd<std::complex<double>>::FnType zhetrd_;

}  // extern "C"

namespace jax {

// Installs each LAPACK routine declared above into the matching kernel's
// `fn` slot. The immediately-invoked lambda runs during dynamic
// initialization of this translation unit's namespace-scope variables; its
// `int` result exists only so the lambda can initialize `init`, which is
// otherwise unused.
static auto init = []() -> int {
  Trsm<float>::fn = strsm_;
  Trsm<double>::fn = dtrsm_;
  Trsm<std::complex<float>>::fn = ctrsm_;
  Trsm<std::complex<double>>::fn = ztrsm_;

  Getrf<float>::fn = sgetrf_;
  Getrf<double>::fn = dgetrf_;
  Getrf<std::complex<float>>::fn = cgetrf_;
  Getrf<std::complex<double>>::fn = zgetrf_;

  Geqrf<float>::fn = sgeqrf_;
  Geqrf<double>::fn = dgeqrf_;
  Geqrf<std::complex<float>>::fn = cgeqrf_;
  Geqrf<std::complex<double>>::fn = zgeqrf_;

  Orgqr<float>::fn = sorgqr_;
  Orgqr<double>::fn = dorgqr_;
  Orgqr<std::complex<float>>::fn = cungqr_;
  Orgqr<std::complex<double>>::fn = zungqr_;

  Potrf<float>::fn = spotrf_;
  Potrf<double>::fn = dpotrf_;
  Potrf<std::complex<float>>::fn = cpotrf_;
  Potrf<std::complex<double>>::fn = zpotrf_;

  RealGesdd<float>::fn = sgesdd_;
  RealGesdd<double>::fn = dgesdd_;
  ComplexGesdd<std::complex<float>>::fn = cgesdd_;
  ComplexGesdd<std::complex<double>>::fn = zgesdd_;

  RealSyevd<float>::fn = ssyevd_;
  RealSyevd<double>::fn = dsyevd_;
  ComplexHeevd<std::complex<float>>::fn = cheevd_;
  ComplexHeevd<std::complex<double>>::fn = zheevd_;

  RealGeev<float>::fn = sgeev_;
  RealGeev<double>::fn = dgeev_;
  ComplexGeev<std::complex<float>>::fn = cgeev_;
  ComplexGeev<std::complex<double>>::fn = zgeev_;

  RealGees<float>::fn = sgees_;
  RealGees<double>::fn = dgees_;
  ComplexGees<std::complex<float>>::fn = cgees_;
  ComplexGees<std::complex<double>>::fn = zgees_;

  Gehrd<float>::fn = sgehrd_;
  Gehrd<double>::fn = dgehrd_;
  Gehrd<std::complex<float>>::fn = cgehrd_;
  Gehrd<std::complex<double>>::fn = zgehrd_;

  Sytrd<float>::fn = ssytrd_;
  Sytrd<double>::fn = dsytrd_;
  Sytrd<std::complex<float>>::fn = chetrd_;
  Sytrd<std::complex<double>>::fn = zhetrd_;

  return 0;
}();

}  // namespace jax