1
/* -*- c -*- */
2

3
#define _UMATHMODULE
4
#define _MULTIARRAYMODULE
5
#define NPY_NO_DEPRECATED_API NPY_API_VERSION
6

7
#include "Python.h"
8

9
#include "npy_config.h"
10
#include "numpy/npy_common.h"
11
#include "numpy/arrayobject.h"
12
#include "numpy/ufuncobject.h"
13
#include "numpy/npy_math.h"
14
#include "numpy/halffloat.h"
15
#include "lowlevel_strided_loops.h"
16

17
#include "npy_pycompat.h"
18

19
#include "npy_cblas.h"
20
#include "arraytypes.h" /* For TYPE_dot functions */
21

22
#include <assert.h>
23

24
/*
 *****************************************************************************
 **                            BASICS                                       **
 *****************************************************************************
 */
29

30
#if defined(HAVE_CBLAS)
31
/*
 * Upper bound for any dimension or leading stride that is handed to BLAS.
 * It is one less than NPY_MAX_INT (or NPY_MAX_INT64 when HAVE_BLAS_ILP64 is
 * defined): -1 to be conservative, in case blas internally uses a for loop
 * with an inclusive upper bound.
 */
#ifndef HAVE_BLAS_ILP64
#define BLAS_MAXSIZE (NPY_MAX_INT - 1)
#else
#define BLAS_MAXSIZE (NPY_MAX_INT64 - 1)
#endif
40

41
/*
42
 * Determine if a 2d matrix can be used by BLAS
43
 * 1. Strides must not alias or overlap
44
 * 2. The faster (second) axis must be contiguous
45
 * 3. The slower (first) axis stride, in unit steps, must be larger than
46
 *    the faster axis dimension
47
 */
48
static NPY_INLINE npy_bool
49
is_blasable2d(npy_intp byte_stride1, npy_intp byte_stride2,
50
              npy_intp d1, npy_intp d2,  npy_intp itemsize)
51
{
52 1
    npy_intp unit_stride1 = byte_stride1 / itemsize;
53 1
    if (byte_stride2 != itemsize) {
54
        return NPY_FALSE;
55
    }
56 1
    if ((byte_stride1 % itemsize ==0) &&
57 1
        (unit_stride1 >= d2) &&
58 1
        (unit_stride1 <= BLAS_MAXSIZE))
59
    {
60
        return NPY_TRUE;
61
    }
62
    return NPY_FALSE;
63
}
64

65
static const npy_cdouble oneD = {1.0, 0.0}, zeroD = {0.0, 0.0};
66
static const npy_cfloat  oneF = {1.0, 0.0}, zeroF = {0.0, 0.0};
67

68
/**begin repeat
 *
 * #name = FLOAT, DOUBLE, CFLOAT, CDOUBLE#
 * #ctype = npy_float, npy_double, npy_cfloat, npy_cdouble#
 * #typ = npy_float, npy_double, npy_cfloat, npy_cdouble#
 * #prefix = s, d, c, z#
 * #step1 = 1.F, 1., &oneF, &oneD#
 * #step0 = 0.F, 0., &zeroF, &zeroD#
 */
77
NPY_NO_EXPORT void
78 1
@name@_gemv(void *ip1, npy_intp is1_m, npy_intp is1_n,
79
            void *ip2, npy_intp is2_n, npy_intp NPY_UNUSED(is2_p),
80
            void *op, npy_intp op_m, npy_intp NPY_UNUSED(op_p),
81
            npy_intp m, npy_intp n, npy_intp NPY_UNUSED(p))
82
{
83
    /*
84
     * Vector matrix multiplication -- Level 2 BLAS
85
     * arguments
86
     * ip1: contiguous data, m*n shape
87
     * ip2: data in c order, n*1 shape
88
     * op:  data in c order, m shape
89
     */
90
    enum CBLAS_ORDER order;
91
    CBLAS_INT M, N, lda;
92

93
    assert(m <= BLAS_MAXSIZE && n <= BLAS_MAXSIZE);
94
    assert (is_blasable2d(is2_n, sizeof(@typ@), n, 1, sizeof(@typ@)));
95 1
    M = (CBLAS_INT)m;
96 1
    N = (CBLAS_INT)n;
97

98 1
    if (is_blasable2d(is1_m, is1_n, m, n, sizeof(@typ@))) {
99 1
        order = CblasColMajor;
100 1
        lda = (CBLAS_INT)(is1_m / sizeof(@typ@));
101
    }
102
    else {
103
        /* If not ColMajor, caller should have ensured we are RowMajor */
104
        /* will not assert in release mode */
105 1
        order = CblasRowMajor;
106
        assert(is_blasable2d(is1_n, is1_m, n, m, sizeof(@typ@)));
107 1
        lda = (CBLAS_INT)(is1_n / sizeof(@typ@));
108
    }
109 1
    CBLAS_FUNC(cblas_@prefix@gemv)(order, CblasTrans, N, M, @step1@, ip1, lda, ip2,
110 1
                                     is2_n / sizeof(@typ@), @step0@, op, op_m / sizeof(@typ@));
111 1
}
112

113
NPY_NO_EXPORT void
114 1
@name@_matmul_matrixmatrix(void *ip1, npy_intp is1_m, npy_intp is1_n,
115
                           void *ip2, npy_intp is2_n, npy_intp is2_p,
116
                           void *op, npy_intp os_m, npy_intp os_p,
117
                           npy_intp m, npy_intp n, npy_intp p)
118
{
119
    /*
120
     * matrix matrix multiplication -- Level 3 BLAS
121
     */
122 1
    enum CBLAS_ORDER order = CblasRowMajor;
123
    enum CBLAS_TRANSPOSE trans1, trans2;
124
    CBLAS_INT M, N, P, lda, ldb, ldc;
125
    assert(m <= BLAS_MAXSIZE && n <= BLAS_MAXSIZE && p <= BLAS_MAXSIZE);
126 1
    M = (CBLAS_INT)m;
127 1
    N = (CBLAS_INT)n;
128 1
    P = (CBLAS_INT)p;
129

130
    assert(is_blasable2d(os_m, os_p, m, p, sizeof(@typ@)));
131 1
    ldc = (CBLAS_INT)(os_m / sizeof(@typ@));
132

133 1
    if (is_blasable2d(is1_m, is1_n, m, n, sizeof(@typ@))) {
134 1
        trans1 = CblasNoTrans;
135 1
        lda = (CBLAS_INT)(is1_m / sizeof(@typ@));
136
    }
137
    else {
138
        /* If not ColMajor, caller should have ensured we are RowMajor */
139
        /* will not assert in release mode */
140
        assert(is_blasable2d(is1_n, is1_m, n, m, sizeof(@typ@)));
141 1
        trans1 = CblasTrans;
142 1
        lda = (CBLAS_INT)(is1_n / sizeof(@typ@));
143
    }
144

145 1
    if (is_blasable2d(is2_n, is2_p, n, p, sizeof(@typ@))) {
146 1
        trans2 = CblasNoTrans;
147 1
        ldb = (CBLAS_INT)(is2_n / sizeof(@typ@));
148
    }
149
    else {
150
        /* If not ColMajor, caller should have ensured we are RowMajor */
151
        /* will not assert in release mode */
152
        assert(is_blasable2d(is2_p, is2_n, p, n, sizeof(@typ@)));
153 1
        trans2 = CblasTrans;
154 1
        ldb = (CBLAS_INT)(is2_p / sizeof(@typ@));
155
    }
156
    /*
157
     * Use syrk if we have a case of a matrix times its transpose.
158
     * Otherwise, use gemm for all other cases.
159
     */
160 1
    if (
161 1
        (ip1 == ip2) &&
162 1
        (m == p) &&
163 1
        (is1_m == is2_p) &&
164 1
        (is1_n == is2_n) &&
165
        (trans1 != trans2)
166
    ) {
167
        npy_intp i,j;
168 1
        if (trans1 == CblasNoTrans) {
169 1
            CBLAS_FUNC(cblas_@prefix@syrk)(
170
                order, CblasUpper, trans1, P, N, @step1@,
171
                ip1, lda, @step0@, op, ldc);
172
        }
173
        else {
174 1
            CBLAS_FUNC(cblas_@prefix@syrk)(
175
                order, CblasUpper, trans1, P, N, @step1@,
176
                ip1, ldb, @step0@, op, ldc);
177
        }
178
        /* Copy the triangle */
179 1
        for (i = 0; i < P; i++) {
180 1
            for (j = i + 1; j < P; j++) {
181 1
                ((@typ@*)op)[j * ldc + i] = ((@typ@*)op)[i * ldc + j];
182
            }
183
        }
184

185
    }
186
    else {
187 1
        CBLAS_FUNC(cblas_@prefix@gemm)(
188
            order, trans1, trans2, M, P, N, @step1@, ip1, lda,
189
            ip2, ldb, @step0@, op, ldc);
190
    }
191 1
}
192

193
/**end repeat**/
194
#endif
195

196
/*
 * matmul loops
 * signature is (m?,n),(n,p?)->(m?,p?)
 */
200

201
/**begin repeat
 *  #TYPE = LONGDOUBLE,
 *          FLOAT, DOUBLE, HALF,
 *          CFLOAT, CDOUBLE, CLONGDOUBLE,
 *          UBYTE, USHORT, UINT, ULONG, ULONGLONG,
 *          BYTE, SHORT, INT, LONG, LONGLONG#
 *  #typ = npy_longdouble,
 *         npy_float,npy_double,npy_half,
 *         npy_cfloat, npy_cdouble, npy_clongdouble,
 *         npy_ubyte, npy_ushort, npy_uint, npy_ulong, npy_ulonglong,
 *         npy_byte, npy_short, npy_int, npy_long, npy_longlong#
 * #IS_COMPLEX = 0, 0, 0, 0, 1, 1, 1, 0*10#
 * #IS_HALF = 0, 0, 0, 1, 0*13#
 */
215

216
NPY_NO_EXPORT void
217 1
@TYPE@_matmul_inner_noblas(void *_ip1, npy_intp is1_m, npy_intp is1_n,
218
                           void *_ip2, npy_intp is2_n, npy_intp is2_p,
219
                           void *_op, npy_intp os_m, npy_intp os_p,
220
                           npy_intp dm, npy_intp dn, npy_intp dp)
221
                           
222
{
223
    npy_intp m, n, p;
224
    npy_intp ib1_n, ib2_n, ib2_p, ob_p;
225 1
    char *ip1 = (char *)_ip1, *ip2 = (char *)_ip2, *op = (char *)_op;
226

227 1
    ib1_n = is1_n * dn;
228 1
    ib2_n = is2_n * dn;
229 1
    ib2_p = is2_p * dp;
230 1
    ob_p  = os_p * dp;
231

232 1
    for (m = 0; m < dm; m++) {
233 1
        for (p = 0; p < dp; p++) {
234
#if @IS_COMPLEX@ == 1
235 1
            (*(@typ@ *)op).real = 0;
236 1
            (*(@typ@ *)op).imag = 0;
237
#elif @IS_HALF@
238
            float sum = 0;
239
#else
240 1
            *(@typ@ *)op = 0;
241
#endif
242 1
            for (n = 0; n < dn; n++) {
243 1
                @typ@ val1 = (*(@typ@ *)ip1);
244 1
                @typ@ val2 = (*(@typ@ *)ip2);
245
#if @IS_HALF@
246 1
                sum += npy_half_to_float(val1) * npy_half_to_float(val2);
247
#elif @IS_COMPLEX@ == 1
248 1
                (*(@typ@ *)op).real += (val1.real * val2.real) -
249 1
                                       (val1.imag * val2.imag);
250 1
                (*(@typ@ *)op).imag += (val1.real * val2.imag) +
251 1
                                       (val1.imag * val2.real);
252
#else
253 1
                *(@typ@ *)op += val1 * val2;
254
#endif
255 1
                ip2 += is2_n;
256 1
                ip1 += is1_n;
257
            }
258
#if @IS_HALF@
259 1
            *(@typ@ *)op = npy_float_to_half(sum);
260
#endif
261 1
            ip1 -= ib1_n;
262 1
            ip2 -= ib2_n;
263 1
            op  +=  os_p;
264 1
            ip2 += is2_p;
265
        }
266 1
        op -= ob_p;
267 1
        ip2 -= ib2_p;
268 1
        ip1 += is1_m;
269 1
        op  +=  os_m;
270
    }
271 1
}
272

273
/**end repeat**/
274
NPY_NO_EXPORT void
275 1
BOOL_matmul_inner_noblas(void *_ip1, npy_intp is1_m, npy_intp is1_n,
276
                           void *_ip2, npy_intp is2_n, npy_intp is2_p,
277
                           void *_op, npy_intp os_m, npy_intp os_p,
278
                           npy_intp dm, npy_intp dn, npy_intp dp)
279
                           
280
{
281
    npy_intp m, n, p;
282
    npy_intp ib2_p, ob_p;
283 1
    char *ip1 = (char *)_ip1, *ip2 = (char *)_ip2, *op = (char *)_op;
284

285 1
    ib2_p = is2_p * dp;
286 1
    ob_p  = os_p * dp;
287

288 1
    for (m = 0; m < dm; m++) {
289 1
        for (p = 0; p < dp; p++) {
290 1
            char *ip1tmp = ip1;
291 1
            char *ip2tmp = ip2;
292 1
            *(npy_bool *)op = NPY_FALSE;
293 1
            for (n = 0; n < dn; n++) {
294 1
                npy_bool val1 = (*(npy_bool *)ip1tmp);
295 1
                npy_bool val2 = (*(npy_bool *)ip2tmp);
296 1
                if (val1 != 0 && val2 != 0) {
297 1
                    *(npy_bool *)op = NPY_TRUE;
298 1
                    break;
299
                }
300 1
                ip2tmp += is2_n;
301 1
                ip1tmp += is1_n;
302
            }
303 1
            op  +=  os_p;
304 1
            ip2 += is2_p;
305
        }
306 1
        op -= ob_p;
307 1
        ip2 -= ib2_p;
308 1
        ip1 += is1_m;
309 1
        op  +=  os_m;
310
    }
311 1
}
312

313
NPY_NO_EXPORT void
314 1
OBJECT_matmul_inner_noblas(void *_ip1, npy_intp is1_m, npy_intp is1_n,
315
                           void *_ip2, npy_intp is2_n, npy_intp is2_p,
316
                           void *_op, npy_intp os_m, npy_intp os_p,
317
                           npy_intp dm, npy_intp dn, npy_intp dp)                         
318
{
319 1
    char *ip1 = (char *)_ip1, *ip2 = (char *)_ip2, *op = (char *)_op;
320

321 1
    npy_intp ib1_n = is1_n * dn;
322 1
    npy_intp ib2_n = is2_n * dn;
323 1
    npy_intp ib2_p = is2_p * dp;
324 1
    npy_intp ob_p  = os_p * dp;
325

326 1
    PyObject *product, *sum_of_products = NULL;
327

328 1
    for (npy_intp m = 0; m < dm; m++) {
329 1
        for (npy_intp p = 0; p < dp; p++) {
330 1
            if ( 0 == dn ) {
331 1
                sum_of_products = PyLong_FromLong(0);
332 1
                if (sum_of_products == NULL) {
333
                    return;
334
                }
335
            }
336

337 1
            for (npy_intp n = 0; n < dn; n++) {
338 1
                PyObject *obj1 = *(PyObject**)ip1, *obj2 = *(PyObject**)ip2;
339 1
                if (obj1 == NULL) {
340 0
                    obj1 = Py_None;
341
                }
342 1
                if (obj2 == NULL) {
343 0
                    obj2 = Py_None;
344
                }
345

346 1
                product = PyNumber_Multiply(obj1, obj2);
347 1
                if (product == NULL) {
348 1
                    Py_XDECREF(sum_of_products);
349
                    return;
350
                }
351

352 1
                if (n == 0) {
353
                    sum_of_products = product;
354
                }
355
                else {
356 1
                    Py_SETREF(sum_of_products, PyNumber_Add(sum_of_products, product));
357 1
                    Py_DECREF(product);
358 1
                    if (sum_of_products == NULL) {
359
                        return;
360
                    }
361
                }
362

363 1
                ip2 += is2_n;
364 1
                ip1 += is1_n;
365
            }
366

367 1
            *((PyObject **)op) = sum_of_products;
368 1
            ip1 -= ib1_n;
369 1
            ip2 -= ib2_n;
370 1
            op  +=  os_p;
371 1
            ip2 += is2_p;
372
        }
373 1
        op -= ob_p;
374 1
        ip2 -= ib2_p;
375 1
        ip1 += is1_m;
376 1
        op  +=  os_m;
377
    }
378
}
379

380

381
/**begin repeat
 *  #TYPE = FLOAT, DOUBLE, LONGDOUBLE, HALF,
 *          CFLOAT, CDOUBLE, CLONGDOUBLE,
 *          UBYTE, USHORT, UINT, ULONG, ULONGLONG,
 *          BYTE, SHORT, INT, LONG, LONGLONG,
 *          BOOL, OBJECT#
 *  #typ = npy_float,npy_double,npy_longdouble, npy_half,
 *         npy_cfloat, npy_cdouble, npy_clongdouble,
 *         npy_ubyte, npy_ushort, npy_uint, npy_ulong, npy_ulonglong,
 *         npy_byte, npy_short, npy_int, npy_long, npy_longlong,
 *         npy_bool,npy_object#
 * #IS_COMPLEX = 0, 0, 0, 0, 1, 1, 1, 0*12#
 * #USEBLAS = 1, 1, 0, 0, 1, 1, 0*13#
 */
395

396

397
NPY_NO_EXPORT void
398 1
@TYPE@_matmul(char **args, npy_intp const *dimensions, npy_intp const *steps, void *NPY_UNUSED(func))
399
{
400 1
    npy_intp dOuter = *dimensions++;
401
    npy_intp iOuter;
402 1
    npy_intp s0 = *steps++;
403 1
    npy_intp s1 = *steps++;
404 1
    npy_intp s2 = *steps++;
405 1
    npy_intp dm = dimensions[0];
406 1
    npy_intp dn = dimensions[1];
407 1
    npy_intp dp = dimensions[2];
408 1
    npy_intp is1_m=steps[0], is1_n=steps[1], is2_n=steps[2], is2_p=steps[3],
409 1
         os_m=steps[4], os_p=steps[5];
410
#if @USEBLAS@ && defined(HAVE_CBLAS)
411 1
    npy_intp sz = sizeof(@typ@);
412 1
    npy_bool special_case = (dm == 1 || dn == 1 || dp == 1);
413 1
    npy_bool any_zero_dim = (dm == 0 || dn == 0 || dp == 0);
414 1
    npy_bool scalar_out = (dm == 1 && dp == 1);
415 1
    npy_bool scalar_vec = (dn == 1 && (dp == 1 || dm == 1));
416 1
    npy_bool too_big_for_blas = (dm > BLAS_MAXSIZE || dn > BLAS_MAXSIZE ||
417
                                 dp > BLAS_MAXSIZE);
418 1
    npy_bool i1_c_blasable = is_blasable2d(is1_m, is1_n, dm, dn, sz);
419 1
    npy_bool i2_c_blasable = is_blasable2d(is2_n, is2_p, dn, dp, sz);
420 1
    npy_bool i1_f_blasable = is_blasable2d(is1_n, is1_m, dn, dm, sz);
421 1
    npy_bool i2_f_blasable = is_blasable2d(is2_p, is2_n, dp, dn, sz);
422 1
    npy_bool i1blasable = i1_c_blasable || i1_f_blasable;
423 1
    npy_bool i2blasable = i2_c_blasable || i2_f_blasable;
424 1
    npy_bool o_c_blasable = is_blasable2d(os_m, os_p, dm, dp, sz);
425 1
    npy_bool o_f_blasable = is_blasable2d(os_p, os_m, dp, dm, sz);
426 1
    npy_bool vector_matrix = ((dm == 1) && i2blasable &&
427 1
                              is_blasable2d(is1_n, sz, dn, 1, sz));
428 1
    npy_bool matrix_vector = ((dp == 1)  && i1blasable &&
429 1
                              is_blasable2d(is2_n, sz, dn, 1, sz));
430
#endif
431

432 1
    for (iOuter = 0; iOuter < dOuter; iOuter++,
433 1
                         args[0] += s0, args[1] += s1, args[2] += s2) {
434 1
        void *ip1=args[0], *ip2=args[1], *op=args[2];
435
#if @USEBLAS@ && defined(HAVE_CBLAS)
436
        /*
437
         * TODO: refactor this out to a inner_loop_selector, in
438
         * PyUFunc_MatmulLoopSelector. But that call does not have access to
439
         * n, m, p and strides.
440
         */
441 1
        if (too_big_for_blas || any_zero_dim) {
442 1
            @TYPE@_matmul_inner_noblas(ip1, is1_m, is1_n, 
443
                                       ip2, is2_n, is2_p,
444
                                       op, os_m, os_p, dm, dn, dp);
445
        }
446 1
        else if (special_case) {
447
            /* Special case variants that have a 1 in the core dimensions */
448 1
            if (scalar_out) {
449
                /* row @ column, 1,1 output */
450 1
                @TYPE@_dot(ip1, is1_n, ip2, is2_n, op, dn, NULL);
451 1
            } else if (scalar_vec){
452
                /*
453
                 * 1,1d @ vector or vector @ 1,1d
454
                 * could use cblas_Xaxy, but that requires 0ing output
455
                 * and would not be faster (XXX prove it)
456
                 */
457 1
                @TYPE@_matmul_inner_noblas(ip1, is1_m, is1_n, 
458
                                           ip2, is2_n, is2_p,
459
                                           op, os_m, os_p, dm, dn, dp);
460 1
            } else if (vector_matrix) {
461
                /* vector @ matrix, switch ip1, ip2, p and m */
462 1
                @TYPE@_gemv(ip2, is2_p, is2_n, ip1, is1_n, is1_m,
463
                            op, os_p, os_m, dp, dn, dm);
464 1
            } else if  (matrix_vector) {
465
                /* matrix @ vector */
466 1
                @TYPE@_gemv(ip1, is1_m, is1_n, ip2, is2_n, is2_p,
467

468
                            op, os_m, os_p, dm, dn, dp);
469
            } else {
470
                /* column @ row, 2d output, no blas needed or non-blas-able input */
471 1
                @TYPE@_matmul_inner_noblas(ip1, is1_m, is1_n, 
472
                                           ip2, is2_n, is2_p,
473
                                           op, os_m, os_p, dm, dn, dp);
474
            }
475
        } else {
476
            /* matrix @ matrix */
477 1
            if (i1blasable && i2blasable && o_c_blasable) {
478 1
                @TYPE@_matmul_matrixmatrix(ip1, is1_m, is1_n,
479
                                           ip2, is2_n, is2_p,
480
                                           op, os_m, os_p,
481
                                           dm, dn, dp);
482 1
            } else if (i1blasable && i2blasable && o_f_blasable) {
483
                /*
484
                 * Use transpose equivalence:
485
                 * matmul(a, b, o) == matmul(b.T, a.T, o.T)
486
                 */
487 1
                @TYPE@_matmul_matrixmatrix(ip2, is2_p, is2_n,
488
                                           ip1, is1_n, is1_m,
489
                                           op, os_p, os_m,
490
                                           dp, dn, dm);
491
            } else {
492
                /*
493
                 * If parameters are castable to int and we copy the
494
                 * non-blasable (or non-ccontiguous output)
495
                 * we could still use BLAS, see gh-12365.
496
                 */
497 1
                @TYPE@_matmul_inner_noblas(ip1, is1_m, is1_n, 
498
                                           ip2, is2_n, is2_p,
499
                                           op, os_m, os_p, dm, dn, dp);
500
            }
501
        }
502
#else
503 1
        @TYPE@_matmul_inner_noblas(ip1, is1_m, is1_n, 
504
                                   ip2, is2_n, is2_p,
505
                                   op, os_m, os_p, dm, dn, dp);
506

507
#endif
508
    }
509 1
}
510

511
/**end repeat**/

Read our documentation on viewing source code .

Loading