LAPACK++  2022.07.00
LAPACK C++ API
cuda_common.hh
1 // Copyright (c) 2017-2020, University of Tennessee. All rights reserved.
2 // SPDX-License-Identifier: BSD-3-Clause
3 // This program is free software: you can redistribute it and/or modify it under
4 // the terms of the BSD 3-Clause license. See the accompanying LICENSE file.
5 
6 #ifndef LAPACK_CUDA_COMMON_HH
7 #define LAPACK_CUDA_COMMON_HH
8 
9 #include "lapack/device.hh"
10 
11 //==============================================================================
12 namespace lapack {
13 
14 //------------------------------------------------------------------------------
/// CudaTraits<scalar_t>::datatype maps scalar_t to cudaDataType.
/// The primary template is declared but not defined here; `datatype` is
/// supplied by the explicit specializations.
template <class scalar_t>
class CudaTraits;
18 
19 //----------
20 // specializations
21 template<>
22 class CudaTraits< float > {
23 public:
24  static constexpr cudaDataType datatype = CUDA_R_32F;
25 };
26 
27 //----------
28 template<>
29 class CudaTraits< double > {
30 public:
31  static constexpr cudaDataType datatype = CUDA_R_64F;
32 };
33 
34 //----------
35 template<>
36 class CudaTraits< std::complex<float> > {
37 public:
38  static constexpr cudaDataType datatype = CUDA_C_32F;
39 };
40 
41 //----------
42 template<>
43 class CudaTraits< std::complex<double> > {
44 public:
45  static constexpr cudaDataType datatype = CUDA_C_64F;
46 };
47 
48 } // namespace lapack
49 
50 //==============================================================================
51 // Inject is_device_error and device_error_string into blas namespace
52 // for blas_dev_call macros.
53 // See blaspp/include/blas/device.hh
54 namespace blas {
55 
56 inline bool is_device_error( cusolverStatus_t status )
57 {
58  return (status != CUSOLVER_STATUS_SUCCESS);
59 }
60 
61 const char* device_error_string( cusolverStatus_t error );
62 
63 } // namespace blas
64 
65 #endif // LAPACK_CUDA_COMMON_HH
lapack::CudaTraits
CudaTraits<scalar_t>::datatype maps scalar_t to cudaDataType.
Definition: cuda_common.hh:17