1
|
|
/* -*- c -*- */
|
2
|
|
|
3
|
|
#define _UMATHMODULE
|
4
|
|
#define _MULTIARRAYMODULE
|
5
|
|
#define NPY_NO_DEPRECATED_API NPY_API_VERSION
|
6
|
|
|
7
|
|
#include "Python.h"
|
8
|
|
|
9
|
|
#include "npy_config.h"
|
10
|
|
#include "numpy/npy_common.h"
|
11
|
|
#include "numpy/arrayobject.h"
|
12
|
|
#include "numpy/ufuncobject.h"
|
13
|
|
#include "numpy/npy_math.h"
|
14
|
|
#include "numpy/halffloat.h"
|
15
|
|
#include "lowlevel_strided_loops.h"
|
16
|
|
|
17
|
|
#include "npy_pycompat.h"
|
18
|
|
|
19
|
|
#include "npy_cblas.h"
|
20
|
|
#include "arraytypes.h" /* For TYPE_dot functions */
|
21
|
|
|
22
|
|
#include <assert.h>
|
23
|
|
|
24
|
|
/*
|
25
|
|
*****************************************************************************
|
26
|
|
** BASICS **
|
27
|
|
*****************************************************************************
|
28
|
|
*/
|
29
|
|
|
30
|
|
#if defined(HAVE_CBLAS)
|
31
|
|
/*
|
32
|
|
* -1 to be conservative, in case blas internally uses a for loop with an
|
33
|
|
* inclusive upper bound
|
34
|
|
*/
|
35
|
|
#ifndef HAVE_BLAS_ILP64
|
36
|
|
#define BLAS_MAXSIZE (NPY_MAX_INT - 1)
|
37
|
|
#else
|
38
|
|
#define BLAS_MAXSIZE (NPY_MAX_INT64 - 1)
|
39
|
|
#endif
|
40
|
|
|
41
|
|
/*
|
42
|
|
* Determine if a 2d matrix can be used by BLAS
|
43
|
|
* 1. Strides must not alias or overlap
|
44
|
|
* 2. The faster (second) axis must be contiguous
|
45
|
|
* 3. The slower (first) axis stride, in unit steps, must be larger than
|
46
|
|
* the faster axis dimension
|
47
|
|
*/
|
48
|
|
static NPY_INLINE npy_bool
|
49
|
|
is_blasable2d(npy_intp byte_stride1, npy_intp byte_stride2,
|
50
|
|
npy_intp d1, npy_intp d2, npy_intp itemsize)
|
51
|
|
{
|
52
|
1
|
npy_intp unit_stride1 = byte_stride1 / itemsize;
|
53
|
1
|
if (byte_stride2 != itemsize) {
|
54
|
|
return NPY_FALSE;
|
55
|
|
}
|
56
|
1
|
if ((byte_stride1 % itemsize ==0) &&
|
57
|
1
|
(unit_stride1 >= d2) &&
|
58
|
1
|
(unit_stride1 <= BLAS_MAXSIZE))
|
59
|
|
{
|
60
|
|
return NPY_TRUE;
|
61
|
|
}
|
62
|
|
return NPY_FALSE;
|
63
|
|
}
|
64
|
|
|
65
|
|
static const npy_cdouble oneD = {1.0, 0.0}, zeroD = {0.0, 0.0};
|
66
|
|
static const npy_cfloat oneF = {1.0, 0.0}, zeroF = {0.0, 0.0};
|
67
|
|
|
68
|
|
/**begin repeat
|
69
|
|
*
|
70
|
|
* #name = FLOAT, DOUBLE, CFLOAT, CDOUBLE#
|
71
|
|
* #ctype = npy_float, npy_double, npy_cfloat, npy_cdouble#
|
72
|
|
* #typ = npy_float, npy_double, npy_cfloat, npy_cdouble#
|
73
|
|
* #prefix = s, d, c, z#
|
74
|
|
* #step1 = 1.F, 1., &oneF, &oneD#
|
75
|
|
* #step0 = 0.F, 0., &zeroF, &zeroD#
|
76
|
|
*/
|
77
|
|
NPY_NO_EXPORT void
|
78
|
1
|
@name@_gemv(void *ip1, npy_intp is1_m, npy_intp is1_n,
|
79
|
|
void *ip2, npy_intp is2_n, npy_intp NPY_UNUSED(is2_p),
|
80
|
|
void *op, npy_intp op_m, npy_intp NPY_UNUSED(op_p),
|
81
|
|
npy_intp m, npy_intp n, npy_intp NPY_UNUSED(p))
|
82
|
|
{
|
83
|
|
/*
|
84
|
|
* Vector matrix multiplication -- Level 2 BLAS
|
85
|
|
* arguments
|
86
|
|
* ip1: contiguous data, m*n shape
|
87
|
|
* ip2: data in c order, n*1 shape
|
88
|
|
* op: data in c order, m shape
|
89
|
|
*/
|
90
|
|
enum CBLAS_ORDER order;
|
91
|
|
CBLAS_INT M, N, lda;
|
92
|
|
|
93
|
|
assert(m <= BLAS_MAXSIZE && n <= BLAS_MAXSIZE);
|
94
|
|
assert (is_blasable2d(is2_n, sizeof(@typ@), n, 1, sizeof(@typ@)));
|
95
|
1
|
M = (CBLAS_INT)m;
|
96
|
1
|
N = (CBLAS_INT)n;
|
97
|
|
|
98
|
1
|
if (is_blasable2d(is1_m, is1_n, m, n, sizeof(@typ@))) {
|
99
|
1
|
order = CblasColMajor;
|
100
|
1
|
lda = (CBLAS_INT)(is1_m / sizeof(@typ@));
|
101
|
|
}
|
102
|
|
else {
|
103
|
|
/* If not ColMajor, caller should have ensured we are RowMajor */
|
104
|
|
/* will not assert in release mode */
|
105
|
1
|
order = CblasRowMajor;
|
106
|
|
assert(is_blasable2d(is1_n, is1_m, n, m, sizeof(@typ@)));
|
107
|
1
|
lda = (CBLAS_INT)(is1_n / sizeof(@typ@));
|
108
|
|
}
|
109
|
1
|
CBLAS_FUNC(cblas_@prefix@gemv)(order, CblasTrans, N, M, @step1@, ip1, lda, ip2,
|
110
|
1
|
is2_n / sizeof(@typ@), @step0@, op, op_m / sizeof(@typ@));
|
111
|
1
|
}
|
112
|
|
|
113
|
|
NPY_NO_EXPORT void
|
114
|
1
|
@name@_matmul_matrixmatrix(void *ip1, npy_intp is1_m, npy_intp is1_n,
|
115
|
|
void *ip2, npy_intp is2_n, npy_intp is2_p,
|
116
|
|
void *op, npy_intp os_m, npy_intp os_p,
|
117
|
|
npy_intp m, npy_intp n, npy_intp p)
|
118
|
|
{
|
119
|
|
/*
|
120
|
|
* matrix matrix multiplication -- Level 3 BLAS
|
121
|
|
*/
|
122
|
1
|
enum CBLAS_ORDER order = CblasRowMajor;
|
123
|
|
enum CBLAS_TRANSPOSE trans1, trans2;
|
124
|
|
CBLAS_INT M, N, P, lda, ldb, ldc;
|
125
|
|
assert(m <= BLAS_MAXSIZE && n <= BLAS_MAXSIZE && p <= BLAS_MAXSIZE);
|
126
|
1
|
M = (CBLAS_INT)m;
|
127
|
1
|
N = (CBLAS_INT)n;
|
128
|
1
|
P = (CBLAS_INT)p;
|
129
|
|
|
130
|
|
assert(is_blasable2d(os_m, os_p, m, p, sizeof(@typ@)));
|
131
|
1
|
ldc = (CBLAS_INT)(os_m / sizeof(@typ@));
|
132
|
|
|
133
|
1
|
if (is_blasable2d(is1_m, is1_n, m, n, sizeof(@typ@))) {
|
134
|
1
|
trans1 = CblasNoTrans;
|
135
|
1
|
lda = (CBLAS_INT)(is1_m / sizeof(@typ@));
|
136
|
|
}
|
137
|
|
else {
|
138
|
|
/* If not ColMajor, caller should have ensured we are RowMajor */
|
139
|
|
/* will not assert in release mode */
|
140
|
|
assert(is_blasable2d(is1_n, is1_m, n, m, sizeof(@typ@)));
|
141
|
1
|
trans1 = CblasTrans;
|
142
|
1
|
lda = (CBLAS_INT)(is1_n / sizeof(@typ@));
|
143
|
|
}
|
144
|
|
|
145
|
1
|
if (is_blasable2d(is2_n, is2_p, n, p, sizeof(@typ@))) {
|
146
|
1
|
trans2 = CblasNoTrans;
|
147
|
1
|
ldb = (CBLAS_INT)(is2_n / sizeof(@typ@));
|
148
|
|
}
|
149
|
|
else {
|
150
|
|
/* If not ColMajor, caller should have ensured we are RowMajor */
|
151
|
|
/* will not assert in release mode */
|
152
|
|
assert(is_blasable2d(is2_p, is2_n, p, n, sizeof(@typ@)));
|
153
|
1
|
trans2 = CblasTrans;
|
154
|
1
|
ldb = (CBLAS_INT)(is2_p / sizeof(@typ@));
|
155
|
|
}
|
156
|
|
/*
|
157
|
|
* Use syrk if we have a case of a matrix times its transpose.
|
158
|
|
* Otherwise, use gemm for all other cases.
|
159
|
|
*/
|
160
|
1
|
if (
|
161
|
1
|
(ip1 == ip2) &&
|
162
|
1
|
(m == p) &&
|
163
|
1
|
(is1_m == is2_p) &&
|
164
|
1
|
(is1_n == is2_n) &&
|
165
|
|
(trans1 != trans2)
|
166
|
|
) {
|
167
|
|
npy_intp i,j;
|
168
|
1
|
if (trans1 == CblasNoTrans) {
|
169
|
1
|
CBLAS_FUNC(cblas_@prefix@syrk)(
|
170
|
|
order, CblasUpper, trans1, P, N, @step1@,
|
171
|
|
ip1, lda, @step0@, op, ldc);
|
172
|
|
}
|
173
|
|
else {
|
174
|
1
|
CBLAS_FUNC(cblas_@prefix@syrk)(
|
175
|
|
order, CblasUpper, trans1, P, N, @step1@,
|
176
|
|
ip1, ldb, @step0@, op, ldc);
|
177
|
|
}
|
178
|
|
/* Copy the triangle */
|
179
|
1
|
for (i = 0; i < P; i++) {
|
180
|
1
|
for (j = i + 1; j < P; j++) {
|
181
|
1
|
((@typ@*)op)[j * ldc + i] = ((@typ@*)op)[i * ldc + j];
|
182
|
|
}
|
183
|
|
}
|
184
|
|
|
185
|
|
}
|
186
|
|
else {
|
187
|
1
|
CBLAS_FUNC(cblas_@prefix@gemm)(
|
188
|
|
order, trans1, trans2, M, P, N, @step1@, ip1, lda,
|
189
|
|
ip2, ldb, @step0@, op, ldc);
|
190
|
|
}
|
191
|
1
|
}
|
192
|
|
|
193
|
|
/**end repeat**/
|
194
|
|
#endif
|
195
|
|
|
196
|
|
/*
|
197
|
|
* matmul loops
|
198
|
|
* signature is (m?,n),(n,p?)->(m?,p?)
|
199
|
|
*/
|
200
|
|
|
201
|
|
/**begin repeat
|
202
|
|
* #TYPE = LONGDOUBLE,
|
203
|
|
* FLOAT, DOUBLE, HALF,
|
204
|
|
* CFLOAT, CDOUBLE, CLONGDOUBLE,
|
205
|
|
* UBYTE, USHORT, UINT, ULONG, ULONGLONG,
|
206
|
|
* BYTE, SHORT, INT, LONG, LONGLONG#
|
207
|
|
* #typ = npy_longdouble,
|
208
|
|
* npy_float,npy_double,npy_half,
|
209
|
|
* npy_cfloat, npy_cdouble, npy_clongdouble,
|
210
|
|
* npy_ubyte, npy_ushort, npy_uint, npy_ulong, npy_ulonglong,
|
211
|
|
* npy_byte, npy_short, npy_int, npy_long, npy_longlong#
|
212
|
|
* #IS_COMPLEX = 0, 0, 0, 0, 1, 1, 1, 0*10#
|
213
|
|
* #IS_HALF = 0, 0, 0, 1, 0*13#
|
214
|
|
*/
|
215
|
|
|
216
|
|
NPY_NO_EXPORT void
|
217
|
1
|
@TYPE@_matmul_inner_noblas(void *_ip1, npy_intp is1_m, npy_intp is1_n,
|
218
|
|
void *_ip2, npy_intp is2_n, npy_intp is2_p,
|
219
|
|
void *_op, npy_intp os_m, npy_intp os_p,
|
220
|
|
npy_intp dm, npy_intp dn, npy_intp dp)
|
221
|
|
|
222
|
|
{
|
223
|
|
npy_intp m, n, p;
|
224
|
|
npy_intp ib1_n, ib2_n, ib2_p, ob_p;
|
225
|
1
|
char *ip1 = (char *)_ip1, *ip2 = (char *)_ip2, *op = (char *)_op;
|
226
|
|
|
227
|
1
|
ib1_n = is1_n * dn;
|
228
|
1
|
ib2_n = is2_n * dn;
|
229
|
1
|
ib2_p = is2_p * dp;
|
230
|
1
|
ob_p = os_p * dp;
|
231
|
|
|
232
|
1
|
for (m = 0; m < dm; m++) {
|
233
|
1
|
for (p = 0; p < dp; p++) {
|
234
|
|
#if @IS_COMPLEX@ == 1
|
235
|
1
|
(*(@typ@ *)op).real = 0;
|
236
|
1
|
(*(@typ@ *)op).imag = 0;
|
237
|
|
#elif @IS_HALF@
|
238
|
|
float sum = 0;
|
239
|
|
#else
|
240
|
1
|
*(@typ@ *)op = 0;
|
241
|
|
#endif
|
242
|
1
|
for (n = 0; n < dn; n++) {
|
243
|
1
|
@typ@ val1 = (*(@typ@ *)ip1);
|
244
|
1
|
@typ@ val2 = (*(@typ@ *)ip2);
|
245
|
|
#if @IS_HALF@
|
246
|
1
|
sum += npy_half_to_float(val1) * npy_half_to_float(val2);
|
247
|
|
#elif @IS_COMPLEX@ == 1
|
248
|
1
|
(*(@typ@ *)op).real += (val1.real * val2.real) -
|
249
|
1
|
(val1.imag * val2.imag);
|
250
|
1
|
(*(@typ@ *)op).imag += (val1.real * val2.imag) +
|
251
|
1
|
(val1.imag * val2.real);
|
252
|
|
#else
|
253
|
1
|
*(@typ@ *)op += val1 * val2;
|
254
|
|
#endif
|
255
|
1
|
ip2 += is2_n;
|
256
|
1
|
ip1 += is1_n;
|
257
|
|
}
|
258
|
|
#if @IS_HALF@
|
259
|
1
|
*(@typ@ *)op = npy_float_to_half(sum);
|
260
|
|
#endif
|
261
|
1
|
ip1 -= ib1_n;
|
262
|
1
|
ip2 -= ib2_n;
|
263
|
1
|
op += os_p;
|
264
|
1
|
ip2 += is2_p;
|
265
|
|
}
|
266
|
1
|
op -= ob_p;
|
267
|
1
|
ip2 -= ib2_p;
|
268
|
1
|
ip1 += is1_m;
|
269
|
1
|
op += os_m;
|
270
|
|
}
|
271
|
1
|
}
|
272
|
|
|
273
|
|
/**end repeat**/
|
274
|
|
NPY_NO_EXPORT void
|
275
|
1
|
BOOL_matmul_inner_noblas(void *_ip1, npy_intp is1_m, npy_intp is1_n,
|
276
|
|
void *_ip2, npy_intp is2_n, npy_intp is2_p,
|
277
|
|
void *_op, npy_intp os_m, npy_intp os_p,
|
278
|
|
npy_intp dm, npy_intp dn, npy_intp dp)
|
279
|
|
|
280
|
|
{
|
281
|
|
npy_intp m, n, p;
|
282
|
|
npy_intp ib2_p, ob_p;
|
283
|
1
|
char *ip1 = (char *)_ip1, *ip2 = (char *)_ip2, *op = (char *)_op;
|
284
|
|
|
285
|
1
|
ib2_p = is2_p * dp;
|
286
|
1
|
ob_p = os_p * dp;
|
287
|
|
|
288
|
1
|
for (m = 0; m < dm; m++) {
|
289
|
1
|
for (p = 0; p < dp; p++) {
|
290
|
1
|
char *ip1tmp = ip1;
|
291
|
1
|
char *ip2tmp = ip2;
|
292
|
1
|
*(npy_bool *)op = NPY_FALSE;
|
293
|
1
|
for (n = 0; n < dn; n++) {
|
294
|
1
|
npy_bool val1 = (*(npy_bool *)ip1tmp);
|
295
|
1
|
npy_bool val2 = (*(npy_bool *)ip2tmp);
|
296
|
1
|
if (val1 != 0 && val2 != 0) {
|
297
|
1
|
*(npy_bool *)op = NPY_TRUE;
|
298
|
1
|
break;
|
299
|
|
}
|
300
|
1
|
ip2tmp += is2_n;
|
301
|
1
|
ip1tmp += is1_n;
|
302
|
|
}
|
303
|
1
|
op += os_p;
|
304
|
1
|
ip2 += is2_p;
|
305
|
|
}
|
306
|
1
|
op -= ob_p;
|
307
|
1
|
ip2 -= ib2_p;
|
308
|
1
|
ip1 += is1_m;
|
309
|
1
|
op += os_m;
|
310
|
|
}
|
311
|
1
|
}
|
312
|
|
|
313
|
|
NPY_NO_EXPORT void
|
314
|
1
|
OBJECT_matmul_inner_noblas(void *_ip1, npy_intp is1_m, npy_intp is1_n,
|
315
|
|
void *_ip2, npy_intp is2_n, npy_intp is2_p,
|
316
|
|
void *_op, npy_intp os_m, npy_intp os_p,
|
317
|
|
npy_intp dm, npy_intp dn, npy_intp dp)
|
318
|
|
{
|
319
|
1
|
char *ip1 = (char *)_ip1, *ip2 = (char *)_ip2, *op = (char *)_op;
|
320
|
|
|
321
|
1
|
npy_intp ib1_n = is1_n * dn;
|
322
|
1
|
npy_intp ib2_n = is2_n * dn;
|
323
|
1
|
npy_intp ib2_p = is2_p * dp;
|
324
|
1
|
npy_intp ob_p = os_p * dp;
|
325
|
|
|
326
|
1
|
PyObject *product, *sum_of_products = NULL;
|
327
|
|
|
328
|
1
|
for (npy_intp m = 0; m < dm; m++) {
|
329
|
1
|
for (npy_intp p = 0; p < dp; p++) {
|
330
|
1
|
if ( 0 == dn ) {
|
331
|
1
|
sum_of_products = PyLong_FromLong(0);
|
332
|
1
|
if (sum_of_products == NULL) {
|
333
|
|
return;
|
334
|
|
}
|
335
|
|
}
|
336
|
|
|
337
|
1
|
for (npy_intp n = 0; n < dn; n++) {
|
338
|
1
|
PyObject *obj1 = *(PyObject**)ip1, *obj2 = *(PyObject**)ip2;
|
339
|
1
|
if (obj1 == NULL) {
|
340
|
0
|
obj1 = Py_None;
|
341
|
|
}
|
342
|
1
|
if (obj2 == NULL) {
|
343
|
0
|
obj2 = Py_None;
|
344
|
|
}
|
345
|
|
|
346
|
1
|
product = PyNumber_Multiply(obj1, obj2);
|
347
|
1
|
if (product == NULL) {
|
348
|
1
|
Py_XDECREF(sum_of_products);
|
349
|
|
return;
|
350
|
|
}
|
351
|
|
|
352
|
1
|
if (n == 0) {
|
353
|
|
sum_of_products = product;
|
354
|
|
}
|
355
|
|
else {
|
356
|
1
|
Py_SETREF(sum_of_products, PyNumber_Add(sum_of_products, product));
|
357
|
1
|
Py_DECREF(product);
|
358
|
1
|
if (sum_of_products == NULL) {
|
359
|
|
return;
|
360
|
|
}
|
361
|
|
}
|
362
|
|
|
363
|
1
|
ip2 += is2_n;
|
364
|
1
|
ip1 += is1_n;
|
365
|
|
}
|
366
|
|
|
367
|
1
|
*((PyObject **)op) = sum_of_products;
|
368
|
1
|
ip1 -= ib1_n;
|
369
|
1
|
ip2 -= ib2_n;
|
370
|
1
|
op += os_p;
|
371
|
1
|
ip2 += is2_p;
|
372
|
|
}
|
373
|
1
|
op -= ob_p;
|
374
|
1
|
ip2 -= ib2_p;
|
375
|
1
|
ip1 += is1_m;
|
376
|
1
|
op += os_m;
|
377
|
|
}
|
378
|
|
}
|
379
|
|
|
380
|
|
|
381
|
|
/**begin repeat
|
382
|
|
* #TYPE = FLOAT, DOUBLE, LONGDOUBLE, HALF,
|
383
|
|
* CFLOAT, CDOUBLE, CLONGDOUBLE,
|
384
|
|
* UBYTE, USHORT, UINT, ULONG, ULONGLONG,
|
385
|
|
* BYTE, SHORT, INT, LONG, LONGLONG,
|
386
|
|
* BOOL, OBJECT#
|
387
|
|
* #typ = npy_float,npy_double,npy_longdouble, npy_half,
|
388
|
|
* npy_cfloat, npy_cdouble, npy_clongdouble,
|
389
|
|
* npy_ubyte, npy_ushort, npy_uint, npy_ulong, npy_ulonglong,
|
390
|
|
* npy_byte, npy_short, npy_int, npy_long, npy_longlong,
|
391
|
|
* npy_bool,npy_object#
|
392
|
|
* #IS_COMPLEX = 0, 0, 0, 0, 1, 1, 1, 0*12#
|
393
|
|
* #USEBLAS = 1, 1, 0, 0, 1, 1, 0*13#
|
394
|
|
*/
|
395
|
|
|
396
|
|
|
397
|
|
NPY_NO_EXPORT void
|
398
|
1
|
@TYPE@_matmul(char **args, npy_intp const *dimensions, npy_intp const *steps, void *NPY_UNUSED(func))
|
399
|
|
{
|
400
|
1
|
npy_intp dOuter = *dimensions++;
|
401
|
|
npy_intp iOuter;
|
402
|
1
|
npy_intp s0 = *steps++;
|
403
|
1
|
npy_intp s1 = *steps++;
|
404
|
1
|
npy_intp s2 = *steps++;
|
405
|
1
|
npy_intp dm = dimensions[0];
|
406
|
1
|
npy_intp dn = dimensions[1];
|
407
|
1
|
npy_intp dp = dimensions[2];
|
408
|
1
|
npy_intp is1_m=steps[0], is1_n=steps[1], is2_n=steps[2], is2_p=steps[3],
|
409
|
1
|
os_m=steps[4], os_p=steps[5];
|
410
|
|
#if @USEBLAS@ && defined(HAVE_CBLAS)
|
411
|
1
|
npy_intp sz = sizeof(@typ@);
|
412
|
1
|
npy_bool special_case = (dm == 1 || dn == 1 || dp == 1);
|
413
|
1
|
npy_bool any_zero_dim = (dm == 0 || dn == 0 || dp == 0);
|
414
|
1
|
npy_bool scalar_out = (dm == 1 && dp == 1);
|
415
|
1
|
npy_bool scalar_vec = (dn == 1 && (dp == 1 || dm == 1));
|
416
|
1
|
npy_bool too_big_for_blas = (dm > BLAS_MAXSIZE || dn > BLAS_MAXSIZE ||
|
417
|
|
dp > BLAS_MAXSIZE);
|
418
|
1
|
npy_bool i1_c_blasable = is_blasable2d(is1_m, is1_n, dm, dn, sz);
|
419
|
1
|
npy_bool i2_c_blasable = is_blasable2d(is2_n, is2_p, dn, dp, sz);
|
420
|
1
|
npy_bool i1_f_blasable = is_blasable2d(is1_n, is1_m, dn, dm, sz);
|
421
|
1
|
npy_bool i2_f_blasable = is_blasable2d(is2_p, is2_n, dp, dn, sz);
|
422
|
1
|
npy_bool i1blasable = i1_c_blasable || i1_f_blasable;
|
423
|
1
|
npy_bool i2blasable = i2_c_blasable || i2_f_blasable;
|
424
|
1
|
npy_bool o_c_blasable = is_blasable2d(os_m, os_p, dm, dp, sz);
|
425
|
1
|
npy_bool o_f_blasable = is_blasable2d(os_p, os_m, dp, dm, sz);
|
426
|
1
|
npy_bool vector_matrix = ((dm == 1) && i2blasable &&
|
427
|
1
|
is_blasable2d(is1_n, sz, dn, 1, sz));
|
428
|
1
|
npy_bool matrix_vector = ((dp == 1) && i1blasable &&
|
429
|
1
|
is_blasable2d(is2_n, sz, dn, 1, sz));
|
430
|
|
#endif
|
431
|
|
|
432
|
1
|
for (iOuter = 0; iOuter < dOuter; iOuter++,
|
433
|
1
|
args[0] += s0, args[1] += s1, args[2] += s2) {
|
434
|
1
|
void *ip1=args[0], *ip2=args[1], *op=args[2];
|
435
|
|
#if @USEBLAS@ && defined(HAVE_CBLAS)
|
436
|
|
/*
|
437
|
|
* TODO: refactor this out to a inner_loop_selector, in
|
438
|
|
* PyUFunc_MatmulLoopSelector. But that call does not have access to
|
439
|
|
* n, m, p and strides.
|
440
|
|
*/
|
441
|
1
|
if (too_big_for_blas || any_zero_dim) {
|
442
|
1
|
@TYPE@_matmul_inner_noblas(ip1, is1_m, is1_n,
|
443
|
|
ip2, is2_n, is2_p,
|
444
|
|
op, os_m, os_p, dm, dn, dp);
|
445
|
|
}
|
446
|
1
|
else if (special_case) {
|
447
|
|
/* Special case variants that have a 1 in the core dimensions */
|
448
|
1
|
if (scalar_out) {
|
449
|
|
/* row @ column, 1,1 output */
|
450
|
1
|
@TYPE@_dot(ip1, is1_n, ip2, is2_n, op, dn, NULL);
|
451
|
1
|
} else if (scalar_vec){
|
452
|
|
/*
|
453
|
|
* 1,1d @ vector or vector @ 1,1d
|
454
|
|
* could use cblas_Xaxy, but that requires 0ing output
|
455
|
|
* and would not be faster (XXX prove it)
|
456
|
|
*/
|
457
|
1
|
@TYPE@_matmul_inner_noblas(ip1, is1_m, is1_n,
|
458
|
|
ip2, is2_n, is2_p,
|
459
|
|
op, os_m, os_p, dm, dn, dp);
|
460
|
1
|
} else if (vector_matrix) {
|
461
|
|
/* vector @ matrix, switch ip1, ip2, p and m */
|
462
|
1
|
@TYPE@_gemv(ip2, is2_p, is2_n, ip1, is1_n, is1_m,
|
463
|
|
op, os_p, os_m, dp, dn, dm);
|
464
|
1
|
} else if (matrix_vector) {
|
465
|
|
/* matrix @ vector */
|
466
|
1
|
@TYPE@_gemv(ip1, is1_m, is1_n, ip2, is2_n, is2_p,
|
467
|
|
|
468
|
|
op, os_m, os_p, dm, dn, dp);
|
469
|
|
} else {
|
470
|
|
/* column @ row, 2d output, no blas needed or non-blas-able input */
|
471
|
1
|
@TYPE@_matmul_inner_noblas(ip1, is1_m, is1_n,
|
472
|
|
ip2, is2_n, is2_p,
|
473
|
|
op, os_m, os_p, dm, dn, dp);
|
474
|
|
}
|
475
|
|
} else {
|
476
|
|
/* matrix @ matrix */
|
477
|
1
|
if (i1blasable && i2blasable && o_c_blasable) {
|
478
|
1
|
@TYPE@_matmul_matrixmatrix(ip1, is1_m, is1_n,
|
479
|
|
ip2, is2_n, is2_p,
|
480
|
|
op, os_m, os_p,
|
481
|
|
dm, dn, dp);
|
482
|
1
|
} else if (i1blasable && i2blasable && o_f_blasable) {
|
483
|
|
/*
|
484
|
|
* Use transpose equivalence:
|
485
|
|
* matmul(a, b, o) == matmul(b.T, a.T, o.T)
|
486
|
|
*/
|
487
|
1
|
@TYPE@_matmul_matrixmatrix(ip2, is2_p, is2_n,
|
488
|
|
ip1, is1_n, is1_m,
|
489
|
|
op, os_p, os_m,
|
490
|
|
dp, dn, dm);
|
491
|
|
} else {
|
492
|
|
/*
|
493
|
|
* If parameters are castable to int and we copy the
|
494
|
|
* non-blasable (or non-ccontiguous output)
|
495
|
|
* we could still use BLAS, see gh-12365.
|
496
|
|
*/
|
497
|
1
|
@TYPE@_matmul_inner_noblas(ip1, is1_m, is1_n,
|
498
|
|
ip2, is2_n, is2_p,
|
499
|
|
op, os_m, os_p, dm, dn, dp);
|
500
|
|
}
|
501
|
|
}
|
502
|
|
#else
|
503
|
1
|
@TYPE@_matmul_inner_noblas(ip1, is1_m, is1_n,
|
504
|
|
ip2, is2_n, is2_p,
|
505
|
|
op, os_m, os_p, dm, dn, dp);
|
506
|
|
|
507
|
|
#endif
|
508
|
|
}
|
509
|
1
|
}
|
510
|
|
|
511
|
|
/**end repeat**/
|