vSMC  v3.0.0
Scalable Monte Carlo
cblas.hpp
Go to the documentation of this file.
1 //============================================================================
2 // vSMC/include/vsmc/math/cblas.hpp
3 //----------------------------------------------------------------------------
4 // vSMC: Scalable Monte Carlo
5 //----------------------------------------------------------------------------
6 // Copyright (c) 2013-2016, Yan Zhou
7 // All rights reserved.
8 //
9 // Redistribution and use in source and binary forms, with or without
10 // modification, are permitted provided that the following conditions are met:
11 //
12 // Redistributions of source code must retain the above copyright notice,
13 // this list of conditions and the following disclaimer.
14 //
15 // Redistributions in binary form must reproduce the above copyright notice,
16 // this list of conditions and the following disclaimer in the documentation
17 // and/or other materials provided with the distribution.
18 //
19 // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS AS IS
20 // AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21 // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
22 // ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
23 // LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
24 // CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
25 // SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
26 // INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
27 // CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
28 // ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
29 // POSSIBILITY OF SUCH DAMAGE.
30 //============================================================================
31 
32 #ifndef VSMC_MATH_CBLAS_HPP
33 #define VSMC_MATH_CBLAS_HPP
34 
35 #include <vsmc/internal/config.h>
36 
37 #if VSMC_USE_CBLAS
38 
39 #if VSMC_USE_MKL_CBLAS
40 #include <mkl_cblas.h>
41 #ifndef VSMC_BLAS_INT
42 #define VSMC_BLAS_INT MKL_INT
43 #endif
44 #elif VSMC_USE_ACCELERATE
45 #include <Accelerate/Accelerate.h>
46 #else
47 #include <cblas.h>
48 #endif
49 
50 #ifndef VSMC_BLAS_INT
51 #define VSMC_BLAS_INT int
52 #endif
53 
54 namespace vsmc
55 {
56 
57 namespace internal
58 {
59 
60 using ::CblasRowMajor;
61 using ::CblasColMajor;
62 
63 using ::CblasNoTrans;
64 using ::CblasTrans;
65 using ::CblasConjTrans;
66 
67 using ::CblasUpper;
68 using ::CblasLower;
69 
70 using ::CblasNonUnit;
71 using ::CblasUnit;
72 
73 using ::CblasLeft;
74 using ::CblasRight;
75 
76 using ::cblas_sdot;
77 using ::cblas_ddot;
78 
79 using ::cblas_sgemv;
80 using ::cblas_dgemv;
81 
82 using ::cblas_stpmv;
83 using ::cblas_dtpmv;
84 
85 using ::cblas_ssyr;
86 using ::cblas_dsyr;
87 
88 using ::cblas_strmm;
89 using ::cblas_dtrmm;
90 
91 using ::cblas_ssyrk;
92 using ::cblas_dsyrk;
93 
94 } // namespace vsmc::internal
95 
96 } // namespace vsmc
97 
98 #else // VSMC_USE_CBLAS
99 
// Name mangling of the Fortran BLAS symbols. Unless the user supplies
// VSMC_BLAS_NAME, a trailing underscore is appended to the routine name;
// define VSMC_BLAS_NAME_NO_UNDERSCORE to use the bare name instead.
#if !defined(VSMC_BLAS_NAME)
#if defined(VSMC_BLAS_NAME_NO_UNDERSCORE)
#define VSMC_BLAS_NAME(x) x
#else
#define VSMC_BLAS_NAME(x) x##_
#endif
#endif // !defined(VSMC_BLAS_NAME)

// Integer type used for dimensions and strides in BLAS calls.
#if !defined(VSMC_BLAS_INT)
#define VSMC_BLAS_INT int
#endif
111 
112 extern "C" {
113 
114 void VSMC_BLAS_NAME(sgemv)(const char *trans, const VSMC_BLAS_INT *m,
115  const VSMC_BLAS_INT *n, const float *alpha, const float *a,
116  const VSMC_BLAS_INT *lda, const float *x, const VSMC_BLAS_INT *incx,
117  const float *beta, float *y, const VSMC_BLAS_INT *incy);
118 
119 void VSMC_BLAS_NAME(dgemv)(const char *trans, const VSMC_BLAS_INT *m,
120  const VSMC_BLAS_INT *n, const double *alpha, const double *a,
121  const VSMC_BLAS_INT *lda, const double *x, const VSMC_BLAS_INT *incx,
122  const double *beta, double *y, const VSMC_BLAS_INT *incy);
123 
124 void VSMC_BLAS_NAME(stpmv)(const char *uplo, const char *trans,
125  const char *diag, const VSMC_BLAS_INT *n, const float *ap, float *x,
126  const VSMC_BLAS_INT *incx);
127 
128 void VSMC_BLAS_NAME(dtpmv)(const char *uplo, const char *trans,
129  const char *diag, const VSMC_BLAS_INT *n, const double *ap, double *x,
130  const VSMC_BLAS_INT *incx);
131 
132 void VSMC_BLAS_NAME(ssyr)(const char *uplo, const VSMC_BLAS_INT *n,
133  const float *alpha, const float *x, const VSMC_BLAS_INT *incx, float *a,
134  const VSMC_BLAS_INT *lda);
135 
136 void VSMC_BLAS_NAME(dsyr)(const char *uplo, const VSMC_BLAS_INT *n,
137  const double *alpha, const double *x, const VSMC_BLAS_INT *incx, double *a,
138  const VSMC_BLAS_INT *lda);
139 
140 void VSMC_BLAS_NAME(strmm)(const char *side, const char *uplo,
141  const char *transa, const char *diag, const VSMC_BLAS_INT *m,
142  const VSMC_BLAS_INT *n, const float *alpha, const float *a,
143  const VSMC_BLAS_INT *lda, float *b, const VSMC_BLAS_INT *ldb);
144 
145 void VSMC_BLAS_NAME(dtrmm)(const char *side, const char *uplo,
146  const char *transa, const char *diag, const VSMC_BLAS_INT *m,
147  const VSMC_BLAS_INT *n, const double *alpha, const double *a,
148  const VSMC_BLAS_INT *lda, double *b, const VSMC_BLAS_INT *ldb);
149 
150 void VSMC_BLAS_NAME(ssyrk)(const char *uplo, const char *trans,
151  const VSMC_BLAS_INT *n, const VSMC_BLAS_INT *k, const float *alpha,
152  const float *a, const VSMC_BLAS_INT *lda, const float *beta, float *c,
153  const VSMC_BLAS_INT *ldc);
154 
155 void VSMC_BLAS_NAME(dsyrk)(const char *uplo, const char *trans,
156  const VSMC_BLAS_INT *n, const VSMC_BLAS_INT *k, const double *alpha,
157  const double *a, const VSMC_BLAS_INT *lda, const double *beta, double *c,
158  const VSMC_BLAS_INT *ldc);
159 
160 } // extern "C"
161 
162 namespace vsmc
163 {
164 
165 namespace internal
166 {
167 
// Stand-ins for the CBLAS enumerations, with the same numeric values as
// the literals below, for builds without a CBLAS header.

// Storage order of matrix arguments
enum CBLAS_LAYOUT {
    CblasRowMajor = 101,
    CblasColMajor = 102
};

// Older name of the layout enumeration
using CBLAS_ORDER = CBLAS_LAYOUT;

// Operation applied to a matrix operand
enum CBLAS_TRANSPOSE {
    CblasNoTrans = 111,
    CblasTrans = 112,
    CblasConjTrans = 113
};

// Which triangle of a symmetric/triangular matrix is referenced
enum CBLAS_UPLO {
    CblasUpper = 121,
    CblasLower = 122
};

// Whether a triangular matrix has an implicit unit diagonal
enum CBLAS_DIAG {
    CblasNonUnit = 131,
    CblasUnit = 132
};

// Side on which a triangular operand is applied
enum CBLAS_SIDE {
    CblasLeft = 141,
    CblasRight = 142
};
183 
184 inline char cblas_trans(const CBLAS_LAYOUT layout, const CBLAS_TRANSPOSE trans)
185 {
186  if (layout == CblasColMajor) {
187  if (trans == CblasNoTrans)
188  return 'N';
189  else if (trans == CblasTrans)
190  return 'T';
191  else if (trans == CblasConjTrans)
192  return 'C';
193  }
194  if (layout == CblasRowMajor) {
195  if (trans == CblasNoTrans)
196  return 'T';
197  else if (trans == CblasTrans)
198  return 'N';
199  else if (trans == CblasConjTrans)
200  return 'N';
201  }
202  return 'N';
203 }
204 
205 inline char cblas_uplo(const CBLAS_LAYOUT layout, const CBLAS_UPLO uplo)
206 {
207  if (layout == CblasColMajor) {
208  if (uplo == CblasUpper)
209  return 'U';
210  if (uplo == CblasLower)
211  return 'L';
212  }
213  if (layout == CblasRowMajor) {
214  if (uplo == CblasUpper)
215  return 'L';
216  if (uplo == CblasLower)
217  return 'U';
218  }
219  return 'U';
220 }
221 
222 inline char cblas_diag(const CBLAS_DIAG diag)
223 {
224  if (diag == CblasUnit)
225  return 'U';
226  if (diag == CblasNonUnit)
227  return 'N';
228  return 'N';
229 }
230 
231 inline char cblas_side(const CBLAS_LAYOUT layout, const CBLAS_SIDE side)
232 {
233  if (layout == CblasColMajor) {
234  if (side == CblasLeft)
235  return 'L';
236  if (side == CblasRight)
237  return 'R';
238  }
239  if (layout == CblasRowMajor) {
240  if (side == CblasLeft)
241  return 'R';
242  if (side == CblasRight)
243  return 'L';
244  }
245  return 'L';
246 }
247 
// Single precision dot product of two strided vectors.
//
// n     Number of elements; a non-positive value yields 0
// x, y  First element of each vector in memory
// incx  Stride between consecutive elements of x
// incy  Stride between consecutive elements of y
//
// As in BLAS, a negative stride traverses the vector backwards, starting
// from its far end, so that `x` and `y` always point to the lowest address
// of the storage. The sum is accumulated in single precision.
inline float cblas_sdot(const VSMC_BLAS_INT n, const float *x,
    const VSMC_BLAS_INT incx, const float *y, const VSMC_BLAS_INT incy)
{
    float s = 0;
    if (n <= 0)
        return s;

    // Wide strides so that index arithmetic cannot overflow VSMC_BLAS_INT
    const long long stride_x = incx;
    const long long stride_y = incy;

    // For a negative stride the first logical element is the last in memory
    const float *const first_x = incx < 0 ? x - (n - 1) * stride_x : x;
    const float *const first_y = incy < 0 ? y - (n - 1) * stride_y : y;

    // Indexing keeps every formed address inside the vectors
    for (long long i = 0; i != n; ++i)
        s += first_x[i * stride_x] * first_y[i * stride_y];

    return s;
}
257 
// Double precision dot product of two strided vectors.
//
// n     Number of elements; a non-positive value yields 0
// x, y  First element of each vector in memory
// incx  Stride between consecutive elements of x
// incy  Stride between consecutive elements of y
//
// As in BLAS, a negative stride traverses the vector backwards, starting
// from its far end, so that `x` and `y` always point to the lowest address
// of the storage.
inline double cblas_ddot(const VSMC_BLAS_INT n, const double *x,
    const VSMC_BLAS_INT incx, const double *y, const VSMC_BLAS_INT incy)
{
    double s = 0;
    if (n <= 0)
        return s;

    // Wide strides so that index arithmetic cannot overflow VSMC_BLAS_INT
    const long long stride_x = incx;
    const long long stride_y = incy;

    // For a negative stride the first logical element is the last in memory
    const double *const first_x = incx < 0 ? x - (n - 1) * stride_x : x;
    const double *const first_y = incy < 0 ? y - (n - 1) * stride_y : y;

    // Indexing keeps every formed address inside the vectors
    for (long long i = 0; i != n; ++i)
        s += first_x[i * stride_x] * first_y[i * stride_y];

    return s;
}
267 
268 inline void cblas_sgemv(const CBLAS_LAYOUT layout, const CBLAS_TRANSPOSE trans,
269  const VSMC_BLAS_INT m, const VSMC_BLAS_INT n, const float alpha,
270  const float *a, const VSMC_BLAS_INT lda, const float *x,
271  const VSMC_BLAS_INT incx, const float beta, float *y,
272  const VSMC_BLAS_INT incy)
273 {
274  const char transf = cblas_trans(layout, trans);
275  VSMC_BLAS_NAME(sgemv)
276  (&transf, &m, &n, &alpha, a, &lda, x, &incx, &beta, y, &incy);
277 }
278 
279 inline void cblas_dgemv(const CBLAS_LAYOUT layout, const CBLAS_TRANSPOSE trans,
280  const VSMC_BLAS_INT m, const VSMC_BLAS_INT n, const double alpha,
281  const double *a, const VSMC_BLAS_INT lda, const double *x,
282  const VSMC_BLAS_INT incx, const double beta, double *y,
283  const VSMC_BLAS_INT incy)
284 {
285  const char transf = cblas_trans(layout, trans);
286  VSMC_BLAS_NAME(dgemv)
287  (&transf, &m, &n, &alpha, a, &lda, x, &incx, &beta, y, &incy);
288 }
289 
290 inline void cblas_stpmv(const CBLAS_LAYOUT layout, const CBLAS_UPLO uplo,
291  const CBLAS_TRANSPOSE trans, const CBLAS_DIAG diag, const VSMC_BLAS_INT n,
292  const float *ap, float *x, const VSMC_BLAS_INT incx)
293 {
294  const char uplof = cblas_uplo(layout, uplo);
295  const char transf = cblas_trans(layout, trans);
296  const char diagf = cblas_diag(diag);
297  VSMC_BLAS_NAME(stpmv)(&uplof, &transf, &diagf, &n, ap, x, &incx);
298 }
299 
300 inline void cblas_dtpmv(CBLAS_LAYOUT layout, const CBLAS_UPLO uplo,
301  const CBLAS_TRANSPOSE trans, const CBLAS_DIAG diag, const VSMC_BLAS_INT n,
302  const double *ap, double *x, const VSMC_BLAS_INT incx)
303 {
304  const char uplof = cblas_uplo(layout, uplo);
305  const char transf = cblas_trans(layout, trans);
306  const char diagf = cblas_diag(diag);
307  VSMC_BLAS_NAME(dtpmv)(&uplof, &transf, &diagf, &n, ap, x, &incx);
308 }
309 
310 inline void cblas_ssyr(const CBLAS_LAYOUT layout, const CBLAS_UPLO uplo,
311  const VSMC_BLAS_INT n, const float alpha, const float *x,
312  const VSMC_BLAS_INT incx, float *a, const VSMC_BLAS_INT lda)
313 {
314  const char uplof = cblas_uplo(layout, uplo);
315  VSMC_BLAS_NAME(ssyr)(&uplof, &n, &alpha, x, &incx, a, &lda);
316 }
317 
318 inline void cblas_dsyr(const CBLAS_LAYOUT layout, const CBLAS_UPLO uplo,
319  const VSMC_BLAS_INT n, const double alpha, const double *x,
320  const VSMC_BLAS_INT incx, double *a, const VSMC_BLAS_INT lda)
321 {
322  const char uplof = cblas_uplo(layout, uplo);
323  VSMC_BLAS_NAME(dsyr)(&uplof, &n, &alpha, x, &incx, a, &lda);
324 }
325 
326 inline void cblas_strmm(const CBLAS_LAYOUT layout, const CBLAS_SIDE side,
327  const CBLAS_UPLO uplo, const CBLAS_TRANSPOSE trans, const CBLAS_DIAG diag,
328  const VSMC_BLAS_INT m, const VSMC_BLAS_INT n, const float alpha,
329  const float *a, const VSMC_BLAS_INT lda, float *b, const VSMC_BLAS_INT ldb)
330 {
331  const char sidef = cblas_side(layout, side);
332  const char uplof = cblas_uplo(layout, uplo);
333  const char transf = cblas_trans(CblasColMajor, trans);
334  const char diagf = cblas_diag(diag);
335  VSMC_BLAS_NAME(strmm)
336  (&sidef, &uplof, &transf, &diagf, &m, &n, &alpha, a, &lda, b, &ldb);
337 }
338 
339 inline void cblas_dtrmm(const CBLAS_LAYOUT layout, const CBLAS_SIDE side,
340  const CBLAS_UPLO uplo, const CBLAS_TRANSPOSE trans, const CBLAS_DIAG diag,
341  const VSMC_BLAS_INT m, const VSMC_BLAS_INT n, const double alpha,
342  const double *a, const VSMC_BLAS_INT lda, double *b,
343  const VSMC_BLAS_INT ldb)
344 {
345  const char sidef = cblas_side(layout, side);
346  const char uplof = cblas_uplo(layout, uplo);
347  const char transf = cblas_trans(CblasColMajor, trans);
348  const char diagf = cblas_diag(diag);
349  VSMC_BLAS_NAME(dtrmm)
350  (&sidef, &uplof, &transf, &diagf, &m, &n, &alpha, a, &lda, b, &ldb);
351 }
352 
353 inline void cblas_ssyrk(const CBLAS_LAYOUT layout, const CBLAS_UPLO uplo,
354  const CBLAS_TRANSPOSE trans, const VSMC_BLAS_INT n, const VSMC_BLAS_INT k,
355  const float alpha, const float *a, const VSMC_BLAS_INT lda,
356  const float beta, float *c, const VSMC_BLAS_INT ldc)
357 {
358  const char uplof = cblas_uplo(layout, uplo);
359  const char transf = cblas_trans(layout, trans);
360  VSMC_BLAS_NAME(ssyrk)
361  (&uplof, &transf, &n, &k, &alpha, a, &lda, &beta, c, &ldc);
362 }
363 
364 inline void cblas_dsyrk(const CBLAS_LAYOUT layout, const CBLAS_UPLO uplo,
365  const CBLAS_TRANSPOSE trans, const VSMC_BLAS_INT n, const VSMC_BLAS_INT k,
366  const double alpha, const double *a, const VSMC_BLAS_INT lda,
367  const double beta, double *c, const VSMC_BLAS_INT ldc)
368 {
369  const char uplof = cblas_uplo(layout, uplo);
370  const char transf = cblas_trans(layout, trans);
371  VSMC_BLAS_NAME(dsyrk)
372  (&uplof, &transf, &n, &k, &alpha, a, &lda, &beta, c, &ldc);
373 }
374 
375 } // namespace vsmc::internal
376 
377 } // namespace vsmc
378 
379 #endif // VSMC_USE_CBLAS
380 
381 #endif // VSMC_MATH_CBLAS_HPP
Definition: monitor.hpp:48
#define VSMC_BLAS_INT
Definition: cblas.hpp:42