LAPACK++  2022.07.00
LAPACK C++ API
device.hh
1 // Copyright (c) 2017-2020, University of Tennessee. All rights reserved.
2 // SPDX-License-Identifier: BSD-3-Clause
3 // This program is free software: you can redistribute it and/or modify it under
4 // the terms of the BSD 3-Clause license. See the accompanying LICENSE file.
5 
6 #ifndef LAPACK_DEVICE_HH
7 #define LAPACK_DEVICE_HH
8 
9 #include "blas/device.hh"
10 #include "lapack/util.hh"
11 
12 #if defined(LAPACK_HAVE_CUBLAS)
13  #include <cusolverDn.h>
14 #endif
15 
16 namespace lapack {
17 
// Integer types used for the info and pivot arrays handed to device solvers.
// Pointers to these integers are passed straight through to the vendor
// library, so each alias must be exactly the type that library expects:
//   - cuSOLVER builds (LAPACK_HAVE_CUBLAS): info is int; pivots are int64_t
//     when CUSOLVER_VERSION >= 11000, otherwise int.
//   - ROCm builds (LAPACK_HAVE_ROCBLAS): rocblas_int for both.
//   - builds with neither: int64_t for both.
#if defined(LAPACK_HAVE_CUBLAS)
    using device_info_int = int;
    #if CUSOLVER_VERSION >= 11000
        using device_pivot_int = int64_t;
    #else
        using device_pivot_int = int;
    #endif
#elif defined(LAPACK_HAVE_ROCBLAS)
    using device_info_int  = rocblas_int;
    using device_pivot_int = rocblas_int;
#else
    using device_info_int  = int64_t;
    using device_pivot_int = int64_t;
#endif
36 
37 //------------------------------------------------------------------------------
38 class Queue: public blas::Queue
39 {
40 public:
41  Queue( int device=-1, int64_t batch_size=30000 )
42  : blas::Queue( device, batch_size )
43  #if defined(LAPACK_HAVE_CUBLAS)
44  , solver_( nullptr )
45  #if CUSOLVER_VERSION >= 11000
46  , solver_params_( nullptr )
47  #endif
48  #endif
49  {}
50 
51  ~Queue()
52  {
53  #if defined(LAPACK_HAVE_CUBLAS)
54  blas::set_device( device() );
55  #if CUSOLVER_VERSION >= 11000
56  if (solver_params_) {
57  cusolverDnDestroyParams( solver_params_ );
58  solver_params_ = nullptr;
59  }
60  #endif
61 
62  if (solver_) {
63  cusolverDnDestroy( solver_ );
64  solver_ = nullptr;
65  }
66  #endif
67  }
68 
69  // Disable copying; must construct anew.
70  Queue( Queue const& ) = delete;
71  Queue& operator=( Queue const& ) = delete;
72 
73  #if defined(LAPACK_HAVE_CUBLAS)
74  cusolverDnHandle_t solver()
76  {
77  if (solver_ == nullptr) {
78  blas::set_device( device() );
79  // todo: error handler
80  cusolverStatus_t status;
81  status = cusolverDnCreate( &solver_ );
82  assert( status == CUSOLVER_STATUS_SUCCESS );
83 
84  assert( stream() != nullptr );
85  status = cusolverDnSetStream( solver_, stream() );
86  assert( status == CUSOLVER_STATUS_SUCCESS );
87  }
88  return solver_;
89  }
90 
91  #if CUSOLVER_VERSION >= 11000
92  cusolverDnParams_t solver_params()
94  {
95  if (solver_params_ == nullptr) {
96  blas::set_device( device() );
97  // todo: error handler
98  cusolverStatus_t status;
99  status = cusolverDnCreateParams( &solver_params_ );
100  assert( status == CUSOLVER_STATUS_SUCCESS );
101  }
102  return solver_params_;
103  }
104  #endif
105  #endif
106 
107 private:
108  #if defined(LAPACK_HAVE_CUBLAS)
109  cusolverDnHandle_t solver_;
110  #if CUSOLVER_VERSION >= 11000
111  cusolverDnParams_t solver_params_;
112  #endif
113  #endif
114 };
115 
116 //------------------------------------------------------------------------------
117 template <typename scalar_t>
118 void potrf(
119  lapack::Uplo uplo, int64_t n,
120  scalar_t* dA, int64_t ldda,
121  device_info_int* dev_info, lapack::Queue& queue );
122 
123 //------------------------------------------------------------------------------
124 template <typename scalar_t>
125 void getrf_work_size_bytes(
126  int64_t m, int64_t n,
127  scalar_t* dA, int64_t ldda,
128  size_t* dev_work_size, size_t* host_work_size,
129  lapack::Queue& queue );
130 
131 template <typename scalar_t>
132 void getrf(
133  int64_t m, int64_t n,
134  scalar_t* dA, int64_t ldda, device_pivot_int* dev_ipiv,
135  void* dev_work, size_t dev_work_size,
136  void* host_work, size_t host_work_size,
137  device_info_int* dev_info, lapack::Queue& queue );
138 
139 //------------------------------------------------------------------------------
140 template <typename scalar_t>
141 void geqrf_work_size_bytes(
142  int64_t m, int64_t n,
143  scalar_t* dA, int64_t ldda,
144  size_t* dev_work_size, size_t* host_work_size,
145  lapack::Queue& queue );
146 
147 template <typename scalar_t>
148 void geqrf(
149  int64_t m, int64_t n,
150  scalar_t* dA, int64_t ldda, scalar_t* dtau,
151  void* dev_work, size_t dev_work_size,
152  void* host_work, size_t host_work_size,
153  device_info_int* dev_info, lapack::Queue& queue );
154 
155 } // namespace lapack
156 
157 #endif // LAPACK_DEVICE_HH