6 #ifndef LAPACK_DEVICE_HH
7 #define LAPACK_DEVICE_HH
9 #include "blas/device.hh"
10 #include "lapack/util.hh"
12 #if defined(LAPACK_HAVE_CUBLAS)
13 #include <cusolverDn.h>
// Integer types behind the `dev_info` and `dev_ipiv` pointer arguments of the
// routines declared later in this header. Which underlying type is used
// depends on the GPU backend macro that is defined.
// NOTE(review): the #else / #endif lines that close these conditionals are not
// visible in this listing -- confirm the branch structure against the full
// header.
20 #if defined(LAPACK_HAVE_CUBLAS)
// LAPACK_HAVE_CUBLAS branch: info values are plain int.
21 typedef int device_info_int;
// Pivot type is int64_t when CUSOLVER_VERSION >= 11000 ...
22 #if CUSOLVER_VERSION >= 11000
23 typedef int64_t device_pivot_int;
// ... and plain int otherwise (presumably the #else branch -- confirm).
25 typedef int device_pivot_int;
// LAPACK_HAVE_ROCBLAS branch: both types are rocblas_int.
28 #elif defined(LAPACK_HAVE_ROCBLAS)
29 typedef rocblas_int device_info_int;
30 typedef rocblas_int device_pivot_int;
// Both types are int64_t here -- presumably the fallback used when neither
// backend macro is defined (confirm).
33 typedef int64_t device_info_int;
34 typedef int64_t device_pivot_int;
// Queue derives from blas::Queue and adds solver state for LAPACK routines.
// When LAPACK_HAVE_CUBLAS is defined it holds a cusolverDnHandle_t and, for
// CUSOLVER_VERSION >= 11000, a cusolverDnParams_t. Each object is created by
// its accessor (solver() / solver_params()) when the stored value is null.
// NOTE(review): this listing omits some lines of the class (braces, access
// specifiers, closing #endif directives, the header of the cleanup routine,
// part of the initializer list) -- confirm details against the full header.
38 class Queue:
public blas::Queue
// Constructor: forwards `device` and `batch_size` unchanged to blas::Queue.
41 Queue(
int device=-1, int64_t batch_size=30000 )
42 : blas::Queue( device, batch_size )
43 #if defined(LAPACK_HAVE_CUBLAS)
45 #if CUSOLVER_VERSION >= 11000
// The params object starts out null; solver_params() creates it when asked.
46 , solver_params_( nullptr )
// Cleanup statements (presumably the destructor body -- its header is not
// visible here): call blas::set_device with this queue's device(), then
// destroy the cuSOLVER params object and the cuSOLVER handle.
53 #if defined(LAPACK_HAVE_CUBLAS)
54 blas::set_device( device() );
55 #if CUSOLVER_VERSION >= 11000
57 cusolverDnDestroyParams( solver_params_ );
// Reset the member to null after destroying the params object.
58 solver_params_ =
nullptr;
63 cusolverDnDestroy( solver_ );
// Copying is disabled: copy construction and copy assignment are deleted.
70 Queue( Queue
const& ) =
delete;
71 Queue& operator=( Queue
const& ) =
delete;
// solver(): available only when LAPACK_HAVE_CUBLAS is defined. If solver_ is
// null it calls blas::set_device( device() ), creates the handle with
// cusolverDnCreate, and attaches it to this queue's stream() with
// cusolverDnSetStream.
// Failures are checked only through assert, so they are not detected when
// asserts are compiled out (NDEBUG).
73 #if defined(LAPACK_HAVE_CUBLAS)
74 cusolverDnHandle_t solver()
77 if (solver_ ==
nullptr) {
78 blas::set_device( device() );
80 cusolverStatus_t status;
81 status = cusolverDnCreate( &solver_ );
82 assert( status == CUSOLVER_STATUS_SUCCESS );
// The handle is bound to the queue's stream, which must already be non-null.
84 assert( stream() !=
nullptr );
85 status = cusolverDnSetStream( solver_, stream() );
86 assert( status == CUSOLVER_STATUS_SUCCESS );
// solver_params(): available only when CUSOLVER_VERSION >= 11000. If
// solver_params_ is null it calls blas::set_device( device() ) and creates
// the object with cusolverDnCreateParams (status checked by assert only);
// it then returns solver_params_.
91 #if CUSOLVER_VERSION >= 11000
92 cusolverDnParams_t solver_params()
95 if (solver_params_ ==
nullptr) {
96 blas::set_device( device() );
98 cusolverStatus_t status;
99 status = cusolverDnCreateParams( &solver_params_ );
100 assert( status == CUSOLVER_STATUS_SUCCESS );
102 return solver_params_;
// Data members: the cuSOLVER handle used by solver(), and the cuSOLVER params
// object used by solver_params(); both are released by the cleanup code above.
108 #if defined(LAPACK_HAVE_CUBLAS)
109 cusolverDnHandle_t solver_;
110 #if CUSOLVER_VERSION >= 11000
111 cusolverDnParams_t solver_params_;
// Routine declaration templated on the matrix element type scalar_t.
// NOTE(review): the line carrying the return type and function name is not
// visible in this listing. The uplo / n / dA / ldda / dev_info signature looks
// like a Cholesky (potrf-style) factorization -- confirm against the full
// header.
//   uplo     - lapack::Uplo selector (presumably which triangle of the
//              matrix is referenced -- confirm)
//   n        - presumably the matrix order
//   dA, ldda - matrix data and its leading dimension (presumably device
//              memory, going by the `d` prefix -- confirm)
//   dev_info - pointer to a device_info_int status value (presumably
//              device-resident -- confirm)
//   queue    - lapack::Queue passed by reference
117 template <
typename scalar_t>
119 lapack::Uplo uplo, int64_t n,
120 scalar_t* dA, int64_t ldda,
121 device_info_int* dev_info, lapack::Queue& queue );
124 template <
typename scalar_t>
125 void getrf_work_size_bytes(
126 int64_t m, int64_t n,
127 scalar_t* dA, int64_t ldda,
128 size_t* dev_work_size,
size_t* host_work_size,
129 lapack::Queue& queue );
// Routine declaration templated on the matrix element type scalar_t.
// NOTE(review): the line carrying the return type and function name is not
// visible in this listing. The parameters line up with
// getrf_work_size_bytes, so this is presumably the LU factorization (getrf)
// itself -- confirm against the full header.
//   m, n           - matrix dimensions (presumably rows, columns)
//   dA, ldda       - matrix data and its leading dimension (presumably
//                    device memory -- confirm)
//   dev_ipiv       - pointer to device_pivot_int pivot values
//   dev_work       - untyped workspace pointer with its size dev_work_size
//   host_work      - untyped workspace pointer with its size host_work_size
//   dev_info       - pointer to a device_info_int status value
//   queue          - lapack::Queue passed by reference
131 template <
typename scalar_t>
133 int64_t m, int64_t n,
134 scalar_t* dA, int64_t ldda, device_pivot_int* dev_ipiv,
135 void* dev_work,
size_t dev_work_size,
136 void* host_work,
size_t host_work_size,
137 device_info_int* dev_info, lapack::Queue& queue );
140 template <
typename scalar_t>
141 void geqrf_work_size_bytes(
142 int64_t m, int64_t n,
143 scalar_t* dA, int64_t ldda,
144 size_t* dev_work_size,
size_t* host_work_size,
145 lapack::Queue& queue );
// Routine declaration templated on the matrix element type scalar_t.
// NOTE(review): the line carrying the return type and function name is not
// visible in this listing. The parameters line up with
// geqrf_work_size_bytes, so this is presumably the QR factorization (geqrf)
// itself -- confirm against the full header.
//   m, n           - matrix dimensions (presumably rows, columns)
//   dA, ldda       - matrix data and its leading dimension (presumably
//                    device memory -- confirm)
//   dtau           - pointer to scalar_t values (presumably the scalar
//                    factors of the elementary reflectors -- confirm)
//   dev_work       - untyped workspace pointer with its size dev_work_size
//   host_work      - untyped workspace pointer with its size host_work_size
//   dev_info       - pointer to a device_info_int status value
//   queue          - lapack::Queue passed by reference
147 template <
typename scalar_t>
149 int64_t m, int64_t n,
150 scalar_t* dA, int64_t ldda, scalar_t* dtau,
151 void* dev_work,
size_t dev_work_size,
152 void* host_work,
size_t host_work_size,
153 device_info_int* dev_info, lapack::Queue& queue );
157 #endif // LAPACK_DEVICE_HH