/*
 * This module provides BLAS-optimized matrix multiplication,
 * inner product and dot for numpy arrays.
 */

6
#define NPY_NO_DEPRECATED_API NPY_API_VERSION
7
#define _MULTIARRAYMODULE
8

9
#include <Python.h>
10
#include <assert.h>
11
#include <numpy/arrayobject.h>
12
#include "npy_cblas.h"
13
#include "arraytypes.h"
14
#include "common.h"
15

16

17
/*
 * (real, imaginary) pairs holding the complex values 1 and 0. They are
 * passed as the alpha and beta scale factors of the complex BLAS routines
 * used in this file.
 */
static const double oneD[2] = {1.0, 0.0}, zeroD[2] = {0.0, 0.0};
static const float oneF[2] = {1.0, 0.0}, zeroF[2] = {0.0, 0.0};
19

20

21
/*
22
 * Helper: dispatch to appropriate cblas_?gemm for typenum.
23
 */
24
static void
25 1
gemm(int typenum, enum CBLAS_ORDER order,
26
     enum CBLAS_TRANSPOSE transA, enum CBLAS_TRANSPOSE transB,
27
     npy_intp m, npy_intp n, npy_intp k,
28
     PyArrayObject *A, npy_intp lda, PyArrayObject *B, npy_intp ldb, PyArrayObject *R)
29
{
30 1
    const void *Adata = PyArray_DATA(A), *Bdata = PyArray_DATA(B);
31 1
    void *Rdata = PyArray_DATA(R);
32 1
    npy_intp ldc = PyArray_DIM(R, 1) > 1 ? PyArray_DIM(R, 1) : 1;
33

34 1
    switch (typenum) {
35 1
        case NPY_DOUBLE:
36 1
            CBLAS_FUNC(cblas_dgemm)(order, transA, transB, m, n, k, 1.,
37
                        Adata, lda, Bdata, ldb, 0., Rdata, ldc);
38 1
            break;
39 1
        case NPY_FLOAT:
40 1
            CBLAS_FUNC(cblas_sgemm)(order, transA, transB, m, n, k, 1.f,
41
                        Adata, lda, Bdata, ldb, 0.f, Rdata, ldc);
42 1
            break;
43 1
        case NPY_CDOUBLE:
44 1
            CBLAS_FUNC(cblas_zgemm)(order, transA, transB, m, n, k, oneD,
45
                        Adata, lda, Bdata, ldb, zeroD, Rdata, ldc);
46 1
            break;
47 1
        case NPY_CFLOAT:
48 1
            CBLAS_FUNC(cblas_cgemm)(order, transA, transB, m, n, k, oneF,
49
                        Adata, lda, Bdata, ldb, zeroF, Rdata, ldc);
50 1
            break;
51
    }
52
}
53

54

55
/*
56
 * Helper: dispatch to appropriate cblas_?gemv for typenum.
57
 */
58
static void
59 1
gemv(int typenum, enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE trans,
60
     PyArrayObject *A, npy_intp lda, PyArrayObject *X, npy_intp incX,
61
     PyArrayObject *R)
62
{
63 1
    const void *Adata = PyArray_DATA(A), *Xdata = PyArray_DATA(X);
64 1
    void *Rdata = PyArray_DATA(R);
65

66 1
    npy_intp m = PyArray_DIM(A, 0), n = PyArray_DIM(A, 1);
67

68 1
    switch (typenum) {
69 1
        case NPY_DOUBLE:
70 1
            CBLAS_FUNC(cblas_dgemv)(order, trans, m, n, 1., Adata, lda, Xdata, incX,
71
                        0., Rdata, 1);
72 1
            break;
73 1
        case NPY_FLOAT:
74 1
            CBLAS_FUNC(cblas_sgemv)(order, trans, m, n, 1.f, Adata, lda, Xdata, incX,
75
                        0.f, Rdata, 1);
76 1
            break;
77 1
        case NPY_CDOUBLE:
78 1
            CBLAS_FUNC(cblas_zgemv)(order, trans, m, n, oneD, Adata, lda, Xdata, incX,
79
                        zeroD, Rdata, 1);
80 1
            break;
81 1
        case NPY_CFLOAT:
82 1
            CBLAS_FUNC(cblas_cgemv)(order, trans, m, n, oneF, Adata, lda, Xdata, incX,
83
                        zeroF, Rdata, 1);
84 1
            break;
85
    }
86
}
87

88

89
/*
90
 * Helper: dispatch to appropriate cblas_?syrk for typenum.
91
 */
92
static void
93 1
syrk(int typenum, enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE trans,
94
     npy_intp n, npy_intp k,
95
     PyArrayObject *A, npy_intp lda, PyArrayObject *R)
96
{
97 1
    const void *Adata = PyArray_DATA(A);
98 1
    void *Rdata = PyArray_DATA(R);
99 1
    npy_intp ldc = PyArray_DIM(R, 1) > 1 ? PyArray_DIM(R, 1) : 1;
100

101
    npy_intp i;
102
    npy_intp j;
103

104 1
    switch (typenum) {
105 1
        case NPY_DOUBLE:
106 1
            CBLAS_FUNC(cblas_dsyrk)(order, CblasUpper, trans, n, k, 1.,
107
                        Adata, lda, 0., Rdata, ldc);
108

109 1
            for (i = 0; i < n; i++) {
110 1
                for (j = i + 1; j < n; j++) {
111 1
                    *((npy_double*)PyArray_GETPTR2(R, j, i)) =
112 1
                            *((npy_double*)PyArray_GETPTR2(R, i, j));
113
                }
114
            }
115
            break;
116 1
        case NPY_FLOAT:
117 1
            CBLAS_FUNC(cblas_ssyrk)(order, CblasUpper, trans, n, k, 1.f,
118
                        Adata, lda, 0.f, Rdata, ldc);
119

120 1
            for (i = 0; i < n; i++) {
121 1
                for (j = i + 1; j < n; j++) {
122 1
                    *((npy_float*)PyArray_GETPTR2(R, j, i)) =
123 1
                            *((npy_float*)PyArray_GETPTR2(R, i, j));
124
                }
125
            }
126
            break;
127 1
        case NPY_CDOUBLE:
128 1
            CBLAS_FUNC(cblas_zsyrk)(order, CblasUpper, trans, n, k, oneD,
129
                        Adata, lda, zeroD, Rdata, ldc);
130

131 1
            for (i = 0; i < n; i++) {
132 1
                for (j = i + 1; j < n; j++) {
133 1
                    *((npy_cdouble*)PyArray_GETPTR2(R, j, i)) =
134 1
                            *((npy_cdouble*)PyArray_GETPTR2(R, i, j));
135
                }
136
            }
137
            break;
138 1
        case NPY_CFLOAT:
139 1
            CBLAS_FUNC(cblas_csyrk)(order, CblasUpper, trans, n, k, oneF,
140
                        Adata, lda, zeroF, Rdata, ldc);
141

142 1
            for (i = 0; i < n; i++) {
143 1
                for (j = i + 1; j < n; j++) {
144 1
                    *((npy_cfloat*)PyArray_GETPTR2(R, j, i)) =
145 1
                            *((npy_cfloat*)PyArray_GETPTR2(R, i, j));
146
                }
147
            }
148
            break;
149
    }
150
}
151

152

153
typedef enum {_scalar, _column, _row, _matrix} MatrixShape;
154

155

156
static MatrixShape
157 1
_select_matrix_shape(PyArrayObject *array)
158
{
159 1
    switch (PyArray_NDIM(array)) {
160
        case 0:
161
            return _scalar;
162 1
        case 1:
163 1
            if (PyArray_DIM(array, 0) > 1)
164
                return _column;
165 1
            return _scalar;
166 1
        case 2:
167 1
            if (PyArray_DIM(array, 0) > 1) {
168 1
                if (PyArray_DIM(array, 1) == 1)
169
                    return _column;
170
                else
171 1
                    return _matrix;
172
            }
173 1
            if (PyArray_DIM(array, 1) == 1)
174
                return _scalar;
175 1
            return _row;
176
    }
177 0
    return _matrix;
178
}
179

180

181
/*
182
 * This also makes sure that the data segment is aligned with
183
 * an itemsize address as well by returning one if not true.
184
 */
185
NPY_NO_EXPORT int
186 1
_bad_strides(PyArrayObject *ap)
187
{
188 1
    int itemsize = PyArray_ITEMSIZE(ap);
189 1
    int i, N=PyArray_NDIM(ap);
190 1
    npy_intp *strides = PyArray_STRIDES(ap);
191 1
    npy_intp *dims = PyArray_DIMS(ap);
192

193 1
    if (((npy_intp)(PyArray_DATA(ap)) % itemsize) != 0) {
194
        return 1;
195
    }
196 1
    for (i = 0; i < N; i++) {
197 1
        if ((strides[i] < 0) || (strides[i] % itemsize) != 0) {
198
            return 1;
199
        }
200 1
        if ((strides[i] == 0 && dims[i] > 1)) {
201
            return 1;
202
        }
203
    }
204

205
    return 0;
206
}
207

208
/*
209
 * dot(a,b)
210
 * Returns the dot product of a and b for arrays of floating point types.
211
 * Like the generic numpy equivalent the product sum is over
212
 * the last dimension of a and the second-to-last dimension of b.
213
 * NB: The first argument is not conjugated.;
214
 *
215
 * This is for use by PyArray_MatrixProduct2. It is assumed on entry that
216
 * the arrays ap1 and ap2 have a common data type given by typenum that is
217
 * float, double, cfloat, or cdouble and have dimension <= 2. The
218
 * __array_ufunc__ nonsense is also assumed to have been taken care of.
219
 */
220
NPY_NO_EXPORT PyObject *
221 1
cblas_matrixproduct(int typenum, PyArrayObject *ap1, PyArrayObject *ap2,
222
                    PyArrayObject *out)
223
{
224 1
    PyArrayObject *result = NULL, *out_buf = NULL;
225
    npy_intp j, lda, ldb;
226
    npy_intp l;
227
    int nd;
228 1
    npy_intp ap1stride = 0;
229
    npy_intp dimensions[NPY_MAXDIMS];
230
    npy_intp numbytes;
231
    MatrixShape ap1shape, ap2shape;
232

233 1
    if (_bad_strides(ap1)) {
234 1
            PyObject *op1 = PyArray_NewCopy(ap1, NPY_ANYORDER);
235

236 1
            Py_DECREF(ap1);
237 1
            ap1 = (PyArrayObject *)op1;
238 1
            if (ap1 == NULL) {
239
                goto fail;
240
            }
241
    }
242 1
    if (_bad_strides(ap2)) {
243 1
            PyObject *op2 = PyArray_NewCopy(ap2, NPY_ANYORDER);
244

245 1
            Py_DECREF(ap2);
246 1
            ap2 = (PyArrayObject *)op2;
247 1
            if (ap2 == NULL) {
248
                goto fail;
249
            }
250
    }
251 1
    ap1shape = _select_matrix_shape(ap1);
252 1
    ap2shape = _select_matrix_shape(ap2);
253

254 1
    if (ap1shape == _scalar || ap2shape == _scalar) {
255
        PyArrayObject *oap1, *oap2;
256 1
        oap1 = ap1; oap2 = ap2;
257
        /* One of ap1 or ap2 is a scalar */
258 1
        if (ap1shape == _scalar) {
259
            /* Make ap2 the scalar */
260 1
            PyArrayObject *t = ap1;
261 1
            ap1 = ap2;
262 1
            ap2 = t;
263 1
            ap1shape = ap2shape;
264 1
            ap2shape = _scalar;
265
        }
266

267 1
        if (ap1shape == _row) {
268 1
            ap1stride = PyArray_STRIDE(ap1, 1);
269
        }
270 1
        else if (PyArray_NDIM(ap1) > 0) {
271 1
            ap1stride = PyArray_STRIDE(ap1, 0);
272
        }
273

274 1
        if (PyArray_NDIM(ap1) == 0 || PyArray_NDIM(ap2) == 0) {
275
            npy_intp *thisdims;
276 1
            if (PyArray_NDIM(ap1) == 0) {
277 1
                nd = PyArray_NDIM(ap2);
278 1
                thisdims = PyArray_DIMS(ap2);
279
            }
280
            else {
281 1
                nd = PyArray_NDIM(ap1);
282 1
                thisdims = PyArray_DIMS(ap1);
283
            }
284 1
            l = 1;
285 1
            for (j = 0; j < nd; j++) {
286 1
                dimensions[j] = thisdims[j];
287 1
                l *= dimensions[j];
288
            }
289
        }
290
        else {
291 1
            l = PyArray_DIM(oap1, PyArray_NDIM(oap1) - 1);
292

293 1
            if (PyArray_DIM(oap2, 0) != l) {
294 0
                dot_alignment_error(oap1, PyArray_NDIM(oap1) - 1, oap2, 0);
295 0
                goto fail;
296
            }
297 1
            nd = PyArray_NDIM(ap1) + PyArray_NDIM(ap2) - 2;
298
            /*
299
             * nd = 0 or 1 or 2. If nd == 0 do nothing ...
300
             */
301 1
            if (nd == 1) {
302
                /*
303
                 * Either PyArray_NDIM(ap1) is 1 dim or PyArray_NDIM(ap2) is
304
                 * 1 dim and the other is 2 dim
305
                 */
306 1
                dimensions[0] = (PyArray_NDIM(oap1) == 2) ?
307 1
                                PyArray_DIM(oap1, 0) : PyArray_DIM(oap2, 1);
308 1
                l = dimensions[0];
309
                /*
310
                 * Fix it so that dot(shape=(N,1), shape=(1,))
311
                 * and dot(shape=(1,), shape=(1,N)) both return
312
                 * an (N,) array (but use the fast scalar code)
313
                 */
314
            }
315 1
            else if (nd == 2) {
316 1
                dimensions[0] = PyArray_DIM(oap1, 0);
317 1
                dimensions[1] = PyArray_DIM(oap2, 1);
318
                /*
319
                 * We need to make sure that dot(shape=(1,1), shape=(1,N))
320
                 * and dot(shape=(N,1),shape=(1,1)) uses
321
                 * scalar multiplication appropriately
322
                 */
323 1
                if (ap1shape == _row) {
324
                    l = dimensions[1];
325
                }
326
                else {
327 1
                    l = dimensions[0];
328
                }
329
            }
330

331
            /* Check if the summation dimension is 0-sized */
332 1
            if (PyArray_DIM(oap1, PyArray_NDIM(oap1) - 1) == 0) {
333 1
                l = 0;
334
            }
335
        }
336
    }
337
    else {
338
        /*
339
         * (PyArray_NDIM(ap1) <= 2 && PyArray_NDIM(ap2) <= 2)
340
         * Both ap1 and ap2 are vectors or matrices
341
         */
342 1
        l = PyArray_DIM(ap1, PyArray_NDIM(ap1) - 1);
343

344 1
        if (PyArray_DIM(ap2, 0) != l) {
345 1
            dot_alignment_error(ap1, PyArray_NDIM(ap1) - 1, ap2, 0);
346 1
            goto fail;
347
        }
348 1
        nd = PyArray_NDIM(ap1) + PyArray_NDIM(ap2) - 2;
349

350 1
        if (nd == 1) {
351 1
            dimensions[0] = (PyArray_NDIM(ap1) == 2) ?
352 1
                            PyArray_DIM(ap1, 0) : PyArray_DIM(ap2, 1);
353
        }
354 1
        else if (nd == 2) {
355 1
            dimensions[0] = PyArray_DIM(ap1, 0);
356 1
            dimensions[1] = PyArray_DIM(ap2, 1);
357
        }
358
    }
359

360 1
    out_buf = new_array_for_sum(ap1, ap2, out, nd, dimensions, typenum, &result);
361 1
    if (out_buf == NULL) {
362
        goto fail;
363
    }
364

365 1
    numbytes = PyArray_NBYTES(out_buf);
366 1
    memset(PyArray_DATA(out_buf), 0, numbytes);
367 1
    if (numbytes == 0 || l == 0) {
368 1
            Py_DECREF(ap1);
369 1
            Py_DECREF(ap2);
370 1
            Py_DECREF(out_buf);
371 1
            return PyArray_Return(result);
372
    }
373

374 1
    if (ap2shape == _scalar) {
375
        /*
376
         * Multiplication by a scalar -- Level 1 BLAS
377
         * if ap1shape is a matrix and we are not contiguous, then we can't
378
         * just blast through the entire array using a single striding factor
379
         */
380 1
        NPY_BEGIN_ALLOW_THREADS;
381

382 1
        if (typenum == NPY_DOUBLE) {
383 1
            if (l == 1) {
384 1
                *((double *)PyArray_DATA(out_buf)) = *((double *)PyArray_DATA(ap2)) *
385 1
                                                 *((double *)PyArray_DATA(ap1));
386
            }
387 1
            else if (ap1shape != _matrix) {
388 1
                CBLAS_FUNC(cblas_daxpy)(l,
389 1
                            *((double *)PyArray_DATA(ap2)),
390 1
                            (double *)PyArray_DATA(ap1),
391 1
                            ap1stride/sizeof(double),
392 1
                            (double *)PyArray_DATA(out_buf), 1);
393
            }
394
            else {
395
                int maxind, oind;
396
                npy_intp i, a1s, outs;
397
                char *ptr, *optr;
398
                double val;
399

400 1
                maxind = (PyArray_DIM(ap1, 0) >= PyArray_DIM(ap1, 1) ? 0 : 1);
401 1
                oind = 1 - maxind;
402 1
                ptr = PyArray_DATA(ap1);
403 1
                optr = PyArray_DATA(out_buf);
404 1
                l = PyArray_DIM(ap1, maxind);
405 1
                val = *((double *)PyArray_DATA(ap2));
406 1
                a1s = PyArray_STRIDE(ap1, maxind) / sizeof(double);
407 1
                outs = PyArray_STRIDE(out_buf, maxind) / sizeof(double);
408 1
                for (i = 0; i < PyArray_DIM(ap1, oind); i++) {
409 1
                    CBLAS_FUNC(cblas_daxpy)(l, val, (double *)ptr, a1s,
410
                                (double *)optr, outs);
411 1
                    ptr += PyArray_STRIDE(ap1, oind);
412 1
                    optr += PyArray_STRIDE(out_buf, oind);
413
                }
414
            }
415
        }
416 1
        else if (typenum == NPY_CDOUBLE) {
417 1
            if (l == 1) {
418
                npy_cdouble *ptr1, *ptr2, *res;
419

420 1
                ptr1 = (npy_cdouble *)PyArray_DATA(ap2);
421 1
                ptr2 = (npy_cdouble *)PyArray_DATA(ap1);
422 1
                res = (npy_cdouble *)PyArray_DATA(out_buf);
423 1
                res->real = ptr1->real * ptr2->real - ptr1->imag * ptr2->imag;
424 1
                res->imag = ptr1->real * ptr2->imag + ptr1->imag * ptr2->real;
425
            }
426 1
            else if (ap1shape != _matrix) {
427 1
                CBLAS_FUNC(cblas_zaxpy)(l,
428 1
                            (double *)PyArray_DATA(ap2),
429 1
                            (double *)PyArray_DATA(ap1),
430 1
                            ap1stride/sizeof(npy_cdouble),
431
                            (double *)PyArray_DATA(out_buf), 1);
432
            }
433
            else {
434
                int maxind, oind;
435
                npy_intp i, a1s, outs;
436
                char *ptr, *optr;
437
                double *pval;
438

439 1
                maxind = (PyArray_DIM(ap1, 0) >= PyArray_DIM(ap1, 1) ? 0 : 1);
440 1
                oind = 1 - maxind;
441 1
                ptr = PyArray_DATA(ap1);
442 1
                optr = PyArray_DATA(out_buf);
443 1
                l = PyArray_DIM(ap1, maxind);
444 1
                pval = (double *)PyArray_DATA(ap2);
445 1
                a1s = PyArray_STRIDE(ap1, maxind) / sizeof(npy_cdouble);
446 1
                outs = PyArray_STRIDE(out_buf, maxind) / sizeof(npy_cdouble);
447 1
                for (i = 0; i < PyArray_DIM(ap1, oind); i++) {
448 1
                    CBLAS_FUNC(cblas_zaxpy)(l, pval, (double *)ptr, a1s,
449
                                (double *)optr, outs);
450 1
                    ptr += PyArray_STRIDE(ap1, oind);
451 1
                    optr += PyArray_STRIDE(out_buf, oind);
452
                }
453
            }
454
        }
455 1
        else if (typenum == NPY_FLOAT) {
456 1
            if (l == 1) {
457 1
                *((float *)PyArray_DATA(out_buf)) = *((float *)PyArray_DATA(ap2)) *
458 1
                    *((float *)PyArray_DATA(ap1));
459
            }
460 1
            else if (ap1shape != _matrix) {
461 1
                CBLAS_FUNC(cblas_saxpy)(l,
462 1
                            *((float *)PyArray_DATA(ap2)),
463 1
                            (float *)PyArray_DATA(ap1),
464 1
                            ap1stride/sizeof(float),
465 1
                            (float *)PyArray_DATA(out_buf), 1);
466
            }
467
            else {
468
                int maxind, oind;
469
                npy_intp i, a1s, outs;
470
                char *ptr, *optr;
471
                float val;
472

473 1
                maxind = (PyArray_DIM(ap1, 0) >= PyArray_DIM(ap1, 1) ? 0 : 1);
474 1
                oind = 1 - maxind;
475 1
                ptr = PyArray_DATA(ap1);
476 1
                optr = PyArray_DATA(out_buf);
477 1
                l = PyArray_DIM(ap1, maxind);
478 1
                val = *((float *)PyArray_DATA(ap2));
479 1
                a1s = PyArray_STRIDE(ap1, maxind) / sizeof(float);
480 1
                outs = PyArray_STRIDE(out_buf, maxind) / sizeof(float);
481 1
                for (i = 0; i < PyArray_DIM(ap1, oind); i++) {
482 1
                    CBLAS_FUNC(cblas_saxpy)(l, val, (float *)ptr, a1s,
483
                                (float *)optr, outs);
484 1
                    ptr += PyArray_STRIDE(ap1, oind);
485 1
                    optr += PyArray_STRIDE(out_buf, oind);
486
                }
487
            }
488
        }
489 1
        else if (typenum == NPY_CFLOAT) {
490 1
            if (l == 1) {
491
                npy_cfloat *ptr1, *ptr2, *res;
492

493 1
                ptr1 = (npy_cfloat *)PyArray_DATA(ap2);
494 1
                ptr2 = (npy_cfloat *)PyArray_DATA(ap1);
495 1
                res = (npy_cfloat *)PyArray_DATA(out_buf);
496 1
                res->real = ptr1->real * ptr2->real - ptr1->imag * ptr2->imag;
497 1
                res->imag = ptr1->real * ptr2->imag + ptr1->imag * ptr2->real;
498
            }
499 1
            else if (ap1shape != _matrix) {
500 1
                CBLAS_FUNC(cblas_caxpy)(l,
501 1
                            (float *)PyArray_DATA(ap2),
502 1
                            (float *)PyArray_DATA(ap1),
503 1
                            ap1stride/sizeof(npy_cfloat),
504
                            (float *)PyArray_DATA(out_buf), 1);
505
            }
506
            else {
507
                int maxind, oind;
508
                npy_intp i, a1s, outs;
509
                char *ptr, *optr;
510
                float *pval;
511

512 1
                maxind = (PyArray_DIM(ap1, 0) >= PyArray_DIM(ap1, 1) ? 0 : 1);
513 1
                oind = 1 - maxind;
514 1
                ptr = PyArray_DATA(ap1);
515 1
                optr = PyArray_DATA(out_buf);
516 1
                l = PyArray_DIM(ap1, maxind);
517 1
                pval = (float *)PyArray_DATA(ap2);
518 1
                a1s = PyArray_STRIDE(ap1, maxind) / sizeof(npy_cfloat);
519 1
                outs = PyArray_STRIDE(out_buf, maxind) / sizeof(npy_cfloat);
520 1
                for (i = 0; i < PyArray_DIM(ap1, oind); i++) {
521 1
                    CBLAS_FUNC(cblas_caxpy)(l, pval, (float *)ptr, a1s,
522
                                (float *)optr, outs);
523 1
                    ptr += PyArray_STRIDE(ap1, oind);
524 1
                    optr += PyArray_STRIDE(out_buf, oind);
525
                }
526
            }
527
        }
528 1
        NPY_END_ALLOW_THREADS;
529
    }
530 1
    else if ((ap2shape == _column) && (ap1shape != _matrix)) {
531 1
        NPY_BEGIN_ALLOW_THREADS;
532

533
        /* Dot product between two vectors -- Level 1 BLAS */
534 1
        PyArray_DESCR(out_buf)->f->dotfunc(
535
                 PyArray_DATA(ap1), PyArray_STRIDE(ap1, (ap1shape == _row)),
536
                 PyArray_DATA(ap2), PyArray_STRIDE(ap2, 0),
537
                 PyArray_DATA(out_buf), l, NULL);
538 1
        NPY_END_ALLOW_THREADS;
539
    }
540 1
    else if (ap1shape == _matrix && ap2shape != _matrix) {
541
        /* Matrix vector multiplication -- Level 2 BLAS */
542
        /* lda must be MAX(M,1) */
543
        enum CBLAS_ORDER Order;
544
        npy_intp ap2s;
545

546 1
        if (!PyArray_ISONESEGMENT(ap1)) {
547
            PyObject *new;
548 1
            new = PyArray_Copy(ap1);
549 1
            Py_DECREF(ap1);
550 1
            ap1 = (PyArrayObject *)new;
551 1
            if (new == NULL) {
552
                goto fail;
553
            }
554
        }
555 1
        NPY_BEGIN_ALLOW_THREADS
556 1
        if (PyArray_ISCONTIGUOUS(ap1)) {
557 1
            Order = CblasRowMajor;
558 1
            lda = (PyArray_DIM(ap1, 1) > 1 ? PyArray_DIM(ap1, 1) : 1);
559
        }
560
        else {
561 1
            Order = CblasColMajor;
562 1
            lda = (PyArray_DIM(ap1, 0) > 1 ? PyArray_DIM(ap1, 0) : 1);
563
        }
564 1
        ap2s = PyArray_STRIDE(ap2, 0) / PyArray_ITEMSIZE(ap2);
565 1
        gemv(typenum, Order, CblasNoTrans, ap1, lda, ap2, ap2s, out_buf);
566 1
        NPY_END_ALLOW_THREADS;
567
    }
568 1
    else if (ap1shape != _matrix && ap2shape == _matrix) {
569
        /* Vector matrix multiplication -- Level 2 BLAS */
570
        enum CBLAS_ORDER Order;
571
        npy_intp ap1s;
572

573 1
        if (!PyArray_ISONESEGMENT(ap2)) {
574
            PyObject *new;
575 1
            new = PyArray_Copy(ap2);
576 1
            Py_DECREF(ap2);
577 1
            ap2 = (PyArrayObject *)new;
578 1
            if (new == NULL) {
579
                goto fail;
580
            }
581
        }
582 1
        NPY_BEGIN_ALLOW_THREADS
583 1
        if (PyArray_ISCONTIGUOUS(ap2)) {
584 1
            Order = CblasRowMajor;
585 1
            lda = (PyArray_DIM(ap2, 1) > 1 ? PyArray_DIM(ap2, 1) : 1);
586
        }
587
        else {
588 1
            Order = CblasColMajor;
589 1
            lda = (PyArray_DIM(ap2, 0) > 1 ? PyArray_DIM(ap2, 0) : 1);
590
        }
591 1
        if (ap1shape == _row) {
592 1
            ap1s = PyArray_STRIDE(ap1, 1) / PyArray_ITEMSIZE(ap1);
593
        }
594
        else {
595 1
            ap1s = PyArray_STRIDE(ap1, 0) / PyArray_ITEMSIZE(ap1);
596
        }
597 1
        gemv(typenum, Order, CblasTrans, ap2, lda, ap1, ap1s, out_buf);
598 1
        NPY_END_ALLOW_THREADS;
599
    }
600
    else {
601
        /*
602
         * (PyArray_NDIM(ap1) == 2 && PyArray_NDIM(ap2) == 2)
603
         * Matrix matrix multiplication -- Level 3 BLAS
604
         *  L x M  multiplied by M x N
605
         */
606
        enum CBLAS_ORDER Order;
607
        enum CBLAS_TRANSPOSE Trans1, Trans2;
608
        npy_intp M, N, L;
609

610
        /* Optimization possible: */
611
        /*
612
         * We may be able to handle single-segment arrays here
613
         * using appropriate values of Order, Trans1, and Trans2.
614
         */
615 1
        if (!PyArray_IS_C_CONTIGUOUS(ap2) && !PyArray_IS_F_CONTIGUOUS(ap2)) {
616 1
            PyObject *new = PyArray_Copy(ap2);
617

618 1
            Py_DECREF(ap2);
619 1
            ap2 = (PyArrayObject *)new;
620 1
            if (new == NULL) {
621
                goto fail;
622
            }
623
        }
624 1
        if (!PyArray_IS_C_CONTIGUOUS(ap1) && !PyArray_IS_F_CONTIGUOUS(ap1)) {
625 1
            PyObject *new = PyArray_Copy(ap1);
626

627 1
            Py_DECREF(ap1);
628 1
            ap1 = (PyArrayObject *)new;
629 1
            if (new == NULL) {
630
                goto fail;
631
            }
632
        }
633

634 1
        NPY_BEGIN_ALLOW_THREADS;
635

636 1
        Order = CblasRowMajor;
637 1
        Trans1 = CblasNoTrans;
638 1
        Trans2 = CblasNoTrans;
639 1
        L = PyArray_DIM(ap1, 0);
640 1
        N = PyArray_DIM(ap2, 1);
641 1
        M = PyArray_DIM(ap2, 0);
642 1
        lda = (PyArray_DIM(ap1, 1) > 1 ? PyArray_DIM(ap1, 1) : 1);
643 1
        ldb = (PyArray_DIM(ap2, 1) > 1 ? PyArray_DIM(ap2, 1) : 1);
644

645
        /*
646
         * Avoid temporary copies for arrays in Fortran order
647
         */
648 1
        if (PyArray_IS_F_CONTIGUOUS(ap1)) {
649 1
            Trans1 = CblasTrans;
650 1
            lda = (PyArray_DIM(ap1, 0) > 1 ? PyArray_DIM(ap1, 0) : 1);
651
        }
652 1
        if (PyArray_IS_F_CONTIGUOUS(ap2)) {
653 1
            Trans2 = CblasTrans;
654 1
            ldb = (PyArray_DIM(ap2, 0) > 1 ? PyArray_DIM(ap2, 0) : 1);
655
        }
656

657
        /*
658
         * Use syrk if we have a case of a matrix times its transpose.
659
         * Otherwise, use gemm for all other cases.
660
         */
661 1
        if (
662 1
            (PyArray_BYTES(ap1) == PyArray_BYTES(ap2)) &&
663 1
            (PyArray_DIM(ap1, 0) == PyArray_DIM(ap2, 1)) &&
664 1
            (PyArray_DIM(ap1, 1) == PyArray_DIM(ap2, 0)) &&
665 1
            (PyArray_STRIDE(ap1, 0) == PyArray_STRIDE(ap2, 1)) &&
666 1
            (PyArray_STRIDE(ap1, 1) == PyArray_STRIDE(ap2, 0)) &&
667 1
            ((Trans1 == CblasTrans) ^ (Trans2 == CblasTrans)) &&
668
            ((Trans1 == CblasNoTrans) ^ (Trans2 == CblasNoTrans))
669
        ) {
670 1
            if (Trans1 == CblasNoTrans) {
671 1
                syrk(typenum, Order, Trans1, N, M, ap1, lda, out_buf);
672
            }
673
            else {
674 1
                syrk(typenum, Order, Trans1, N, M, ap2, ldb, out_buf);
675
            }
676
        }
677
        else {
678 1
            gemm(typenum, Order, Trans1, Trans2, L, N, M, ap1, lda, ap2, ldb,
679
                 out_buf);
680
        }
681 1
        NPY_END_ALLOW_THREADS;
682
    }
683

684

685 1
    Py_DECREF(ap1);
686 1
    Py_DECREF(ap2);
687

688
    /* Trigger possible copyback into `result` */
689 1
    PyArray_ResolveWritebackIfCopy(out_buf);
690 1
    Py_DECREF(out_buf);
691

692 1
    return PyArray_Return(result);
693

694 1
fail:
695 1
    Py_XDECREF(ap1);
696 1
    Py_XDECREF(ap2);
697 1
    Py_XDECREF(out_buf);
698 1
    Py_XDECREF(result);
699
    return NULL;
700
}

Read our documentation on viewing source code .

Loading