1
/*
2
 * This file contains the implementation of the 'einsum' function,
3
 * which provides an Einstein-summation operation.
4
 *
5
 * Copyright (c) 2011 by Mark Wiebe (mwwiebe@gmail.com)
6
 * The University of British Columbia
7
 *
8
 * See LICENSE.txt for the license.
9
 */
10

11
#define PY_SSIZE_T_CLEAN
12
#include "Python.h"
13
#include "structmember.h"
14

15
#define NPY_NO_DEPRECATED_API NPY_API_VERSION
16
#define _MULTIARRAYMODULE
17
#include <numpy/npy_common.h>
18
#include <numpy/arrayobject.h>
19
#include <npy_pycompat.h>
20

21
#include <ctype.h>
22

23
#include "convert.h"
24
#include "common.h"
25
#include "ctors.h"
26

27
#include "einsum_sumprod.h"
28
#include "einsum_debug.h"
29

30

31
/*
32
 * Parses the subscripts for one operand into an output of 'ndim'
33
 * labels. The resulting 'op_labels' array will have:
34
 *  - the ASCII code of the label for the first occurrence of a label;
35
 *  - the (negative) offset to the first occurrence of the label for
36
 *    repeated labels;
37
 *  - zero for broadcast dimensions, if subscripts has an ellipsis.
38
 * For example:
39
 *  - subscripts="abbcbc",  ndim=6 -> op_labels=[97, 98, -1, 99, -3, -2]
40
 *  - subscripts="ab...bc", ndim=6 -> op_labels=[97, 98, 0, 0, -3, 99]
41
 */
42

43
static int
44 1
parse_operand_subscripts(char *subscripts, int length,
45
                         int ndim, int iop, char *op_labels,
46
                         char *label_counts, int *min_label, int *max_label)
47
{
48
    int i;
49 1
    int idim = 0;
50 1
    int ellipsis = -1;
51

52
    /* Process all labels for this operand */
53 1
    for (i = 0; i < length; ++i) {
54 1
        int label = subscripts[i];
55

56
        /* A proper label for an axis. */
57 1
        if (label > 0 && isalpha(label)) {
58
            /* Check we don't exceed the operator dimensions. */
59 1
            if (idim >= ndim) {
60 1
                PyErr_Format(PyExc_ValueError,
61
                             "einstein sum subscripts string contains "
62
                             "too many subscripts for operand %d", iop);
63 1
                return -1;
64
            }
65

66 1
            op_labels[idim++] = label;
67 1
            if (label < *min_label) {
68 1
                *min_label = label;
69
            }
70 1
            if (label > *max_label) {
71 1
                *max_label = label;
72
            }
73 1
            label_counts[label]++;
74
        }
75
        /* The beginning of the ellipsis. */
76 1
        else if (label == '.') {
77
            /* Check it's a proper ellipsis. */
78 1
            if (ellipsis != -1 || i + 2 >= length
79 1
                    || subscripts[++i] != '.' || subscripts[++i] != '.') {
80 1
                PyErr_Format(PyExc_ValueError,
81
                             "einstein sum subscripts string contains a "
82
                             "'.' that is not part of an ellipsis ('...') "
83
                             "in operand %d", iop);
84 1
                return -1;
85
            }
86

87
            ellipsis = idim;
88
        }
89 1
        else if (label != ' ') {
90 1
            PyErr_Format(PyExc_ValueError,
91
                         "invalid subscript '%c' in einstein sum "
92
                         "subscripts string, subscripts must "
93
                         "be letters", (char)label);
94 1
            return -1;
95
        }
96
    }
97

98
    /* No ellipsis found, labels must match dimensions exactly. */
99 1
    if (ellipsis == -1) {
100 1
        if (idim != ndim) {
101 1
            PyErr_Format(PyExc_ValueError,
102
                         "operand has more dimensions than subscripts "
103
                         "given in einstein sum, but no '...' ellipsis "
104
                         "provided to broadcast the extra dimensions.");
105 1
            return -1;
106
        }
107
    }
108
    /* Ellipsis found, may have to add broadcast dimensions. */
109 1
    else if (idim < ndim) {
110
        /* Move labels after ellipsis to the end. */
111 1
        for (i = 0; i < idim - ellipsis; ++i) {
112 1
            op_labels[ndim - i - 1] = op_labels[idim - i - 1];
113
        }
114
        /* Set all broadcast dimensions to zero. */
115 1
        for (i = 0; i < ndim - idim; ++i) {
116 1
            op_labels[ellipsis + i] = 0;
117
        }
118
    }
119

120
    /*
121
     * Find any labels duplicated for this operand, and turn them
122
     * into negative offsets to the axis to merge with.
123
     *
124
     * In C, the char type may be signed or unsigned, but with
125
     * twos complement arithmetic the char is ok either way here, and
126
     * later where it matters the char is cast to a signed char.
127
     */
128 1
    for (idim = 0; idim < ndim - 1; ++idim) {
129 1
        int label = (signed char)op_labels[idim];
130
        /* If it is a proper label, find any duplicates of it. */
131 1
        if (label > 0) {
132
            /* Search for the next matching label. */
133 1
            char *next = memchr(op_labels + idim + 1, label, ndim - idim - 1);
134

135 1
            while (next != NULL) {
136
                /* The offset from next to op_labels[idim] (negative). */
137 1
                *next = (char)((op_labels + idim) - next);
138
                /* Search for the next matching label. */
139 1
                next = memchr(next + 1, label, op_labels + ndim - 1 - next);
140
            }
141
        }
142
    }
143

144
    return 0;
145
}
146

147

148
/*
149
 * Parses the subscripts for the output operand into an output that
150
 * includes 'ndim_broadcast' unlabeled dimensions, and returns the total
151
 * number of output dimensions, or -1 if there is an error. Similarly
152
 * to parse_operand_subscripts, the 'out_labels' array will have, for
153
 * each dimension:
154
 *  - the ASCII code of the corresponding label;
155
 *  - zero for broadcast dimensions, if subscripts has an ellipsis.
156
 */
157
static int
158 1
parse_output_subscripts(char *subscripts, int length,
159
                        int ndim_broadcast,
160
                        const char *label_counts, char *out_labels)
161
{
162
    int i, bdim;
163 1
    int ndim = 0;
164 1
    int ellipsis = 0;
165

166
    /* Process all the output labels. */
167 1
    for (i = 0; i < length; ++i) {
168 1
        int label = subscripts[i];
169

170
        /* A proper label for an axis. */
171 1
        if (label > 0 && isalpha(label)) {
172
            /* Check that it doesn't occur again. */
173 1
            if (memchr(subscripts + i + 1, label, length - i - 1) != NULL) {
174 1
                PyErr_Format(PyExc_ValueError,
175
                             "einstein sum subscripts string includes "
176
                             "output subscript '%c' multiple times",
177
                             (char)label);
178 1
                return -1;
179
            }
180
            /* Check that it was used in the inputs. */
181 1
            if (label_counts[label] == 0) {
182 1
                PyErr_Format(PyExc_ValueError,
183
                             "einstein sum subscripts string included "
184
                             "output subscript '%c' which never appeared "
185
                             "in an input", (char)label);
186 1
                return -1;
187
            }
188
            /* Check that there is room in out_labels for this label. */
189 1
            if (ndim >= NPY_MAXDIMS) {
190 0
                PyErr_Format(PyExc_ValueError,
191
                             "einstein sum subscripts string contains "
192
                             "too many subscripts in the output");
193 0
                return -1;
194
            }
195

196 1
            out_labels[ndim++] = label;
197
        }
198
        /* The beginning of the ellipsis. */
199 1
        else if (label == '.') {
200
            /* Check it is a proper ellipsis. */
201 1
            if (ellipsis || i + 2 >= length
202 1
                    || subscripts[++i] != '.' || subscripts[++i] != '.') {
203 1
                PyErr_SetString(PyExc_ValueError,
204
                                "einstein sum subscripts string "
205
                                "contains a '.' that is not part of "
206
                                "an ellipsis ('...') in the output");
207 1
                return -1;
208
            }
209
            /* Check there is room in out_labels for broadcast dims. */
210 1
            if (ndim + ndim_broadcast > NPY_MAXDIMS) {
211 0
                PyErr_Format(PyExc_ValueError,
212
                             "einstein sum subscripts string contains "
213
                             "too many subscripts in the output");
214 0
                return -1;
215
            }
216

217
            ellipsis = 1;
218 1
            for (bdim = 0; bdim < ndim_broadcast; ++bdim) {
219 1
                out_labels[ndim++] = 0;
220
            }
221
        }
222 1
        else if (label != ' ') {
223 1
            PyErr_Format(PyExc_ValueError,
224
                         "invalid subscript '%c' in einstein sum "
225
                         "subscripts string, subscripts must "
226
                         "be letters", (char)label);
227 1
            return -1;
228
        }
229
    }
230

231
    /* If no ellipsis was found there should be no broadcast dimensions. */
232 1
    if (!ellipsis && ndim_broadcast > 0) {
233 0
        PyErr_SetString(PyExc_ValueError,
234
                        "output has more dimensions than subscripts "
235
                        "given in einstein sum, but no '...' ellipsis "
236
                        "provided to broadcast the extra dimensions.");
237 0
        return -1;
238
    }
239

240
    return ndim;
241
}
242

243

244
/*
245
 * When there's just one operand and no reduction we can return a view
246
 * into 'op'.  This calculates the view and stores it in 'ret', if
247
 * possible.  Returns -1 on error, 0 otherwise.  Note that a 0 return
248
 * does not mean that a view was successfully created.
249
 */
250
static int
251 1
get_single_op_view(PyArrayObject *op, char *labels,
252
                   int ndim_output, char *output_labels,
253
                   PyArrayObject **ret)
254
{
255
    npy_intp new_strides[NPY_MAXDIMS];
256
    npy_intp new_dims[NPY_MAXDIMS];
257
    char *out_label;
258 1
    int label, i, idim, ndim, ibroadcast = 0;
259

260 1
    ndim = PyArray_NDIM(op);
261

262
    /* Initialize the dimensions and strides to zero */
263 1
    for (idim = 0; idim < ndim_output; ++idim) {
264 1
        new_dims[idim] = 0;
265 1
        new_strides[idim] = 0;
266
    }
267

268
    /* Match the labels in the operand with the output labels */
269 1
    for (idim = 0; idim < ndim; ++idim) {
270
        /*
271
         * The char type may be either signed or unsigned, we
272
         * need it to be signed here.
273
         */
274 1
        label = (signed char)labels[idim];
275
        /* If this label says to merge axes, get the actual label */
276 1
        if (label < 0) {
277 1
            label = labels[idim+label];
278
        }
279
        /* If the label is 0, it's an unlabeled broadcast dimension */
280 1
        if (label == 0) {
281
            /* The next output label that's a broadcast dimension */
282 1
            for (; ibroadcast < ndim_output; ++ibroadcast) {
283 1
                if (output_labels[ibroadcast] == 0) {
284
                    break;
285
                }
286
            }
287 1
            if (ibroadcast == ndim_output) {
288 0
                PyErr_SetString(PyExc_ValueError,
289
                        "output had too few broadcast dimensions");
290 0
                return -1;
291
            }
292 1
            new_dims[ibroadcast] = PyArray_DIM(op, idim);
293 1
            new_strides[ibroadcast] = PyArray_STRIDE(op, idim);
294 1
            ++ibroadcast;
295
        }
296
        else {
297
            /* Find the position for this dimension in the output */
298 1
            out_label = (char *)memchr(output_labels, label,
299
                                                    ndim_output);
300
            /* If it's not found, reduction -> can't return a view */
301 1
            if (out_label == NULL) {
302
                break;
303
            }
304
            /* Update the dimensions and strides of the output */
305 1
            i = out_label - output_labels;
306 1
            if (new_dims[i] != 0 && new_dims[i] != PyArray_DIM(op, idim)) {
307 1
                PyErr_Format(PyExc_ValueError,
308
                        "dimensions in single operand for collapsing "
309
                        "index '%c' don't match (%d != %d)",
310 1
                        label, (int)new_dims[i], (int)PyArray_DIM(op, idim));
311 1
                return -1;
312
            }
313 1
            new_dims[i] = PyArray_DIM(op, idim);
314 1
            new_strides[i] += PyArray_STRIDE(op, idim);
315
        }
316
    }
317
    /* If we processed all the input axes, return a view */
318 1
    if (idim == ndim) {
319 1
        Py_INCREF(PyArray_DESCR(op));
320 1
        *ret = (PyArrayObject *)PyArray_NewFromDescr_int(
321 1
                Py_TYPE(op), PyArray_DESCR(op),
322
                ndim_output, new_dims, new_strides, PyArray_DATA(op),
323 1
                PyArray_ISWRITEABLE(op) ? NPY_ARRAY_WRITEABLE : 0,
324
                (PyObject *)op, (PyObject *)op,
325
                0, 0);
326

327 1
        if (*ret == NULL) {
328
            return -1;
329
        }
330 1
        return 0;
331
    }
332

333
    /* Return success, but that we couldn't make a view */
334 1
    *ret = NULL;
335 1
    return 0;
336
}
337

338

339
/*
 * Reports whether any of the first 'ndim' entries of 'labels' is below
 * zero. Negative entries are the "merge with an earlier axis" offsets
 * that mark repeated labels. Returns 1 if one is found, else 0.
 *
 * The parameter is 'signed char' on purpose: plain char may be unsigned,
 * in which case a negative offset would never compare less than zero.
 */
static int
_any_labels_are_negative(signed char *labels, int ndim)
{
    int remaining = ndim;

    while (remaining > 0) {
        --remaining;
        if (labels[remaining] < 0) {
            return 1;
        }
    }

    return 0;
}
356

357
/*
358
 * Given the labels for an operand array, returns a view of the array
359
 * with all repeated labels collapsed into a single dimension along
360
 * the corresponding diagonal. The labels are also updated to match
361
 * the dimensions of the new array. If no label is repeated, the
362
 * original array is reference increased and returned unchanged.
363
 */
364
static PyArrayObject *
365 1
get_combined_dims_view(PyArrayObject *op, int iop, char *labels)
366
{
367
    npy_intp new_strides[NPY_MAXDIMS];
368
    npy_intp new_dims[NPY_MAXDIMS];
369
    int idim, icombine;
370
    int icombinemap[NPY_MAXDIMS];
371 1
    int ndim = PyArray_NDIM(op);
372 1
    PyArrayObject *ret = NULL;
373

374
    /* A fast path to avoid unnecessary calculations. */
375 1
    if (!_any_labels_are_negative((signed char *)labels, ndim)) {
376 1
        Py_INCREF(op);
377

378 1
        return op;
379
    }
380

381
    /* Combine repeated labels. */
382
    icombine = 0;
383 1
    for(idim = 0; idim < ndim; ++idim) {
384
        /*
385
         * The char type may be either signed or unsigned, we
386
         * need it to be signed here.
387
         */
388 1
        int label = (signed char)labels[idim];
389 1
        npy_intp dim = PyArray_DIM(op, idim);
390 1
        npy_intp stride = PyArray_STRIDE(op, idim);
391

392
        /* A label seen for the first time, add it to the op view. */
393 1
        if (label >= 0) {
394
            /*
395
             * icombinemap maps dimensions in the original array to
396
             * their position in the combined dimensions view.
397
             */
398 1
            icombinemap[idim] = icombine;
399 1
            new_dims[icombine] = dim;
400 1
            new_strides[icombine] = stride;
401 1
            ++icombine;
402
        }
403
        /* A repeated label, find the original one and merge them. */
404
        else {
405
#ifdef __GNUC__
406
#pragma GCC diagnostic push
407
#pragma GCC diagnostic ignored "-Wuninitialized"
408
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
409
#endif
410 1
            int i = icombinemap[idim + label];
411

412 1
            icombinemap[idim] = -1;
413 1
            if (new_dims[i] != dim) {
414 1
                char orig_label = labels[idim + label];
415 1
                PyErr_Format(PyExc_ValueError,
416
                             "dimensions in operand %d for collapsing "
417
                             "index '%c' don't match (%d != %d)",
418
                             iop, orig_label, (int)new_dims[i], (int)dim);
419 1
                return NULL;
420
            }
421 1
            new_strides[i] += stride;
422
#ifdef __GNUC__
423
#pragma GCC diagnostic pop
424
#endif
425
        }
426
    }
427

428
    /* Overwrite labels to match the new operand view. */
429 1
    for (idim = 0; idim < ndim; ++idim) {
430 1
        int i = icombinemap[idim];
431

432 1
        if (i >= 0) {
433 1
            labels[i] = labels[idim];
434
        }
435
    }
436

437
    /* The number of dimensions of the combined view. */
438 1
    ndim = icombine;
439

440
    /* Create a view of the operand with the compressed dimensions. */
441 1
    Py_INCREF(PyArray_DESCR(op));
442 1
    ret = (PyArrayObject *)PyArray_NewFromDescrAndBase(
443 1
            Py_TYPE(op), PyArray_DESCR(op),
444
            ndim, new_dims, new_strides, PyArray_DATA(op),
445 1
            PyArray_ISWRITEABLE(op) ? NPY_ARRAY_WRITEABLE : 0,
446
            (PyObject *)op, (PyObject *)op);
447

448 1
    return ret;
449
}
450

451
/*
 * Builds the operand-axis map for one operand. For each of the
 * 'ndim_iter' iterator dimensions (labels in 'iter_labels'), 'axes[i]'
 * receives the index of the matching axis among the operand's 'ndim'
 * axes (labels in 'labels'), or -1 when the operand must be broadcast
 * along that dimension.
 *
 * Matching rules:
 *  - A labeled iterator dimension maps to the first operand axis
 *    carrying the same label.
 *  - Unlabeled (zero) iterator dimensions are paired with the operand's
 *    unlabeled axes from right to left. Once those run out, the
 *    remaining leading ones get -1.
 *
 * 'iop' is not used. Always returns 0.
 */
static int
prepare_op_axes(int ndim, int iop, char *labels, int *axes,
            int ndim_iter, char *iter_labels)
{
    int next_bcast = ndim - 1;
    int pos;

    (void)iop;

    for (pos = ndim_iter - 1; pos >= 0; --pos) {
        int label = iter_labels[pos];

        if (label != 0) {
            char *hit = memchr(labels, label, ndim);
            axes[pos] = (hit != NULL) ? (int)(hit - labels) : -1;
            continue;
        }

        /* Skip labeled operand axes to reach the next unlabeled one. */
        while (next_bcast >= 0 && labels[next_bcast] != 0) {
            --next_bcast;
        }

        if (next_bcast >= 0) {
            axes[pos] = next_bcast;
            --next_bcast;
        }
        else {
            /* Operand has no broadcast axis left: insert a new axis. */
            axes[pos] = -1;
        }
    }

    return 0;
}
497

498
static int
499 1
unbuffered_loop_nop1_ndim2(NpyIter *iter)
500
{
501
    npy_intp coord, shape[2], strides[2][2];
502
    char *ptrs[2][2], *ptr;
503
    sum_of_products_fn sop;
504 1
    NPY_BEGIN_THREADS_DEF;
505

506
#if NPY_EINSUM_DBG_TRACING
507
    NpyIter_DebugPrint(iter);
508
#endif
509
    NPY_EINSUM_DBG_PRINT("running hand-coded 1-op 2-dim loop\n");
510

511 1
    NpyIter_GetShape(iter, shape);
512 1
    memcpy(strides[0], NpyIter_GetAxisStrideArray(iter, 0),
513
                                            2*sizeof(npy_intp));
514 1
    memcpy(strides[1], NpyIter_GetAxisStrideArray(iter, 1),
515
                                            2*sizeof(npy_intp));
516 1
    memcpy(ptrs[0], NpyIter_GetInitialDataPtrArray(iter),
517
                                            2*sizeof(char *));
518 1
    memcpy(ptrs[1], ptrs[0], 2*sizeof(char*));
519

520 1
    sop = get_sum_of_products_function(1,
521 1
                    NpyIter_GetDescrArray(iter)[0]->type_num,
522 1
                    NpyIter_GetDescrArray(iter)[0]->elsize,
523
                    strides[0]);
524

525 1
    if (sop == NULL) {
526 0
        PyErr_SetString(PyExc_TypeError,
527
                    "invalid data type for einsum");
528 0
        return -1;
529
    }
530

531
    /*
532
     * Since the iterator wasn't tracking coordinates, the
533
     * loop provided by the iterator is in Fortran-order.
534
     */
535 1
    NPY_BEGIN_THREADS_THRESHOLDED(shape[1] * shape[0]);
536 1
    for (coord = shape[1]; coord > 0; --coord) {
537 1
        sop(1, ptrs[0], strides[0], shape[0]);
538

539 1
        ptr = ptrs[1][0] + strides[1][0];
540 1
        ptrs[0][0] = ptrs[1][0] = ptr;
541 1
        ptr = ptrs[1][1] + strides[1][1];
542 1
        ptrs[0][1] = ptrs[1][1] = ptr;
543
    }
544 1
    NPY_END_THREADS;
545

546
    return 0;
547
}
548

549
static int
550 0
unbuffered_loop_nop1_ndim3(NpyIter *iter)
551
{
552
    npy_intp coords[2], shape[3], strides[3][2];
553
    char *ptrs[3][2], *ptr;
554
    sum_of_products_fn sop;
555 0
    NPY_BEGIN_THREADS_DEF;
556

557
#if NPY_EINSUM_DBG_TRACING
558
    NpyIter_DebugPrint(iter);
559
#endif
560
    NPY_EINSUM_DBG_PRINT("running hand-coded 1-op 3-dim loop\n");
561

562 0
    NpyIter_GetShape(iter, shape);
563 0
    memcpy(strides[0], NpyIter_GetAxisStrideArray(iter, 0),
564
                                            2*sizeof(npy_intp));
565 0
    memcpy(strides[1], NpyIter_GetAxisStrideArray(iter, 1),
566
                                            2*sizeof(npy_intp));
567 0
    memcpy(strides[2], NpyIter_GetAxisStrideArray(iter, 2),
568
                                            2*sizeof(npy_intp));
569 0
    memcpy(ptrs[0], NpyIter_GetInitialDataPtrArray(iter),
570
                                            2*sizeof(char *));
571 0
    memcpy(ptrs[1], ptrs[0], 2*sizeof(char*));
572 0
    memcpy(ptrs[2], ptrs[0], 2*sizeof(char*));
573

574 0
    sop = get_sum_of_products_function(1,
575 0
                    NpyIter_GetDescrArray(iter)[0]->type_num,
576 0
                    NpyIter_GetDescrArray(iter)[0]->elsize,
577
                    strides[0]);
578

579 0
    if (sop == NULL) {
580 0
        PyErr_SetString(PyExc_TypeError,
581
                    "invalid data type for einsum");
582 0
        return -1;
583
    }
584

585
    /*
586
     * Since the iterator wasn't tracking coordinates, the
587
     * loop provided by the iterator is in Fortran-order.
588
     */
589 0
    NPY_BEGIN_THREADS_THRESHOLDED(shape[2] * shape[1] * shape[0]);
590 0
    for (coords[1] = shape[2]; coords[1] > 0; --coords[1]) {
591 0
        for (coords[0] = shape[1]; coords[0] > 0; --coords[0]) {
592 0
            sop(1, ptrs[0], strides[0], shape[0]);
593

594 0
            ptr = ptrs[1][0] + strides[1][0];
595 0
            ptrs[0][0] = ptrs[1][0] = ptr;
596 0
            ptr = ptrs[1][1] + strides[1][1];
597 0
            ptrs[0][1] = ptrs[1][1] = ptr;
598
        }
599 0
        ptr = ptrs[2][0] + strides[2][0];
600 0
        ptrs[0][0] = ptrs[1][0] = ptrs[2][0] = ptr;
601 0
        ptr = ptrs[2][1] + strides[2][1];
602 0
        ptrs[0][1] = ptrs[1][1] = ptrs[2][1] = ptr;
603
    }
604 0
    NPY_END_THREADS;
605

606
    return 0;
607
}
608

609
static int
610 1
unbuffered_loop_nop2_ndim2(NpyIter *iter)
611
{
612
    npy_intp coord, shape[2], strides[2][3];
613
    char *ptrs[2][3], *ptr;
614
    sum_of_products_fn sop;
615 1
    NPY_BEGIN_THREADS_DEF;
616

617
#if NPY_EINSUM_DBG_TRACING
618
    NpyIter_DebugPrint(iter);
619
#endif
620
    NPY_EINSUM_DBG_PRINT("running hand-coded 2-op 2-dim loop\n");
621

622 1
    NpyIter_GetShape(iter, shape);
623 1
    memcpy(strides[0], NpyIter_GetAxisStrideArray(iter, 0),
624
                                            3*sizeof(npy_intp));
625 1
    memcpy(strides[1], NpyIter_GetAxisStrideArray(iter, 1),
626
                                            3*sizeof(npy_intp));
627 1
    memcpy(ptrs[0], NpyIter_GetInitialDataPtrArray(iter),
628
                                            3*sizeof(char *));
629 1
    memcpy(ptrs[1], ptrs[0], 3*sizeof(char*));
630

631 1
    sop = get_sum_of_products_function(2,
632 1
                    NpyIter_GetDescrArray(iter)[0]->type_num,
633 1
                    NpyIter_GetDescrArray(iter)[0]->elsize,
634
                    strides[0]);
635

636 1
    if (sop == NULL) {
637 0
        PyErr_SetString(PyExc_TypeError,
638
                    "invalid data type for einsum");
639 0
        return -1;
640
    }
641

642
    /*
643
     * Since the iterator wasn't tracking coordinates, the
644
     * loop provided by the iterator is in Fortran-order.
645
     */
646 1
    NPY_BEGIN_THREADS_THRESHOLDED(shape[1] * shape[0]);
647 1
    for (coord = shape[1]; coord > 0; --coord) {
648 1
        sop(2, ptrs[0], strides[0], shape[0]);
649

650 1
        ptr = ptrs[1][0] + strides[1][0];
651 1
        ptrs[0][0] = ptrs[1][0] = ptr;
652 1
        ptr = ptrs[1][1] + strides[1][1];
653 1
        ptrs[0][1] = ptrs[1][1] = ptr;
654 1
        ptr = ptrs[1][2] + strides[1][2];
655 1
        ptrs[0][2] = ptrs[1][2] = ptr;
656
    }
657 1
    NPY_END_THREADS;
658

659
    return 0;
660
}
661

662
static int
663 1
unbuffered_loop_nop2_ndim3(NpyIter *iter)
664
{
665
    npy_intp coords[2], shape[3], strides[3][3];
666
    char *ptrs[3][3], *ptr;
667
    sum_of_products_fn sop;
668 1
    NPY_BEGIN_THREADS_DEF;
669

670
#if NPY_EINSUM_DBG_TRACING
671
    NpyIter_DebugPrint(iter);
672
#endif
673
    NPY_EINSUM_DBG_PRINT("running hand-coded 2-op 3-dim loop\n");
674

675 1
    NpyIter_GetShape(iter, shape);
676 1
    memcpy(strides[0], NpyIter_GetAxisStrideArray(iter, 0),
677
                                            3*sizeof(npy_intp));
678 1
    memcpy(strides[1], NpyIter_GetAxisStrideArray(iter, 1),
679
                                            3*sizeof(npy_intp));
680 1
    memcpy(strides[2], NpyIter_GetAxisStrideArray(iter, 2),
681
                                            3*sizeof(npy_intp));
682 1
    memcpy(ptrs[0], NpyIter_GetInitialDataPtrArray(iter),
683
                                            3*sizeof(char *));
684 1
    memcpy(ptrs[1], ptrs[0], 3*sizeof(char*));
685 1
    memcpy(ptrs[2], ptrs[0], 3*sizeof(char*));
686

687 1
    sop = get_sum_of_products_function(2,
688 1
                    NpyIter_GetDescrArray(iter)[0]->type_num,
689 1
                    NpyIter_GetDescrArray(iter)[0]->elsize,
690
                    strides[0]);
691

692 1
    if (sop == NULL) {
693 0
        PyErr_SetString(PyExc_TypeError,
694
                    "invalid data type for einsum");
695 0
        return -1;
696
    }
697

698
    /*
699
     * Since the iterator wasn't tracking coordinates, the
700
     * loop provided by the iterator is in Fortran-order.
701
     */
702 1
    NPY_BEGIN_THREADS_THRESHOLDED(shape[2] * shape[1] * shape[0]);
703 1
    for (coords[1] = shape[2]; coords[1] > 0; --coords[1]) {
704 1
        for (coords[0] = shape[1]; coords[0] > 0; --coords[0]) {
705 1
            sop(2, ptrs[0], strides[0], shape[0]);
706

707 1
            ptr = ptrs[1][0] + strides[1][0];
708 1
            ptrs[0][0] = ptrs[1][0] = ptr;
709 1
            ptr = ptrs[1][1] + strides[1][1];
710 1
            ptrs[0][1] = ptrs[1][1] = ptr;
711 1
            ptr = ptrs[1][2] + strides[1][2];
712 1
            ptrs[0][2] = ptrs[1][2] = ptr;
713
        }
714 1
        ptr = ptrs[2][0] + strides[2][0];
715 1
        ptrs[0][0] = ptrs[1][0] = ptrs[2][0] = ptr;
716 1
        ptr = ptrs[2][1] + strides[2][1];
717 1
        ptrs[0][1] = ptrs[1][1] = ptrs[2][1] = ptr;
718 1
        ptr = ptrs[2][2] + strides[2][2];
719 1
        ptrs[0][2] = ptrs[1][2] = ptrs[2][2] = ptr;
720
    }
721 1
    NPY_END_THREADS;
722

723
    return 0;
724
}
725

726

727
/*NUMPY_API
728
 * This function provides summation of array elements according to
729
 * the Einstein summation convention.  For example:
730
 *  - trace(a)        -> einsum("ii", a)
731
 *  - transpose(a)    -> einsum("ji", a)
732
 *  - multiply(a,b)   -> einsum(",", a, b)
733
 *  - inner(a,b)      -> einsum("i,i", a, b)
734
 *  - outer(a,b)      -> einsum("i,j", a, b)
735
 *  - matvec(a,b)     -> einsum("ij,j", a, b)
736
 *  - matmat(a,b)     -> einsum("ij,jk", a, b)
737
 *
738
 * subscripts: The string of subscripts for einstein summation.
739
 * nop:        The number of operands
740
 * op_in:      The array of operands
741
 * dtype:      Either NULL, or the data type to force the calculation as.
742
 * order:      The order for the calculation/the output axes.
743
 * casting:    What kind of casts should be permitted.
744
 * out:        Either NULL, or an array into which the output should be placed.
745
 *
746
 * By default, the labels get placed in alphabetical order
747
 * at the end of the output. So, if c = einsum("i,j", a, b)
748
 * then c[i,j] == a[i]*b[j], but if c = einsum("j,i", a, b)
749
 * then c[i,j] = a[j]*b[i].
750
 *
751
 * Alternatively, you can control the output order or prevent
752
 * an axis from being summed/force an axis to be summed by providing
753
 * indices for the output. This allows us to turn 'trace' into
754
 * 'diag', for example.
755
 *  - diag(a)         -> einsum("ii->i", a)
756
 *  - sum(a, axis=0)  -> einsum("i...->", a)
757
 *
758
 * Subscripts at the beginning and end may be specified by
759
 * putting an ellipsis "..." in the middle.  For example,
760
 * the function einsum("i...i", a) takes the diagonal of
761
 * the first and last dimensions of the operand, and
762
 * einsum("ij...,jk...->ik...") takes the matrix product using
763
 * the first two indices of each operand instead of the last two.
764
 *
765
 * When there is only one operand, no axes being summed, and
766
 * no output parameter, this function returns a view
767
 * into the operand instead of making a copy.
768
 */
769
NPY_NO_EXPORT PyArrayObject *
770 1
PyArray_EinsteinSum(char *subscripts, npy_intp nop,
771
                    PyArrayObject **op_in,
772
                    PyArray_Descr *dtype,
773
                    NPY_ORDER order, NPY_CASTING casting,
774
                    PyArrayObject *out)
775
{
776 1
    int iop, label, min_label = 127, max_label = 0;
777
    char label_counts[128];
778
    char op_labels[NPY_MAXARGS][NPY_MAXDIMS];
779
    char output_labels[NPY_MAXDIMS], *iter_labels;
780
    int idim, ndim_output, ndim_broadcast, ndim_iter;
781

782 1
    PyArrayObject *op[NPY_MAXARGS], *ret = NULL;
783
    PyArray_Descr *op_dtypes_array[NPY_MAXARGS], **op_dtypes;
784

785
    int op_axes_arrays[NPY_MAXARGS][NPY_MAXDIMS];
786
    int *op_axes[NPY_MAXARGS];
787
    npy_uint32 iter_flags, op_flags[NPY_MAXARGS];
788

789
    NpyIter *iter;
790
    sum_of_products_fn sop;
791
    npy_intp fixed_strides[NPY_MAXARGS];
792

793
    /* nop+1 (+1 is for the output) must fit in NPY_MAXARGS */
794 1
    if (nop >= NPY_MAXARGS) {
795 0
        PyErr_SetString(PyExc_ValueError,
796
                    "too many operands provided to einstein sum function");
797 0
        return NULL;
798
    }
799 1
    else if (nop < 1) {
800 0
        PyErr_SetString(PyExc_ValueError,
801
                    "not enough operands provided to einstein sum function");
802 0
        return NULL;
803
    }
804

805
    /* Parse the subscripts string into label_counts and op_labels */
806 1
    memset(label_counts, 0, sizeof(label_counts));
807 1
    for (iop = 0; iop < nop; ++iop) {
808 1
        int length = (int)strcspn(subscripts, ",-");
809

810 1
        if (iop == nop-1 && subscripts[length] == ',') {
811 1
            PyErr_SetString(PyExc_ValueError,
812
                        "more operands provided to einstein sum function "
813
                        "than specified in the subscripts string");
814 1
            return NULL;
815
        }
816 1
        else if(iop < nop-1 && subscripts[length] != ',') {
817 1
            PyErr_SetString(PyExc_ValueError,
818
                        "fewer operands provided to einstein sum function "
819
                        "than specified in the subscripts string");
820 1
            return NULL;
821
        }
822

823 1
        if (parse_operand_subscripts(subscripts, length,
824 1
                        PyArray_NDIM(op_in[iop]),
825 1
                        iop, op_labels[iop], label_counts,
826
                        &min_label, &max_label) < 0) {
827
            return NULL;
828
        }
829

830
        /* Move subscripts to the start of the labels for the next op */
831 1
        subscripts += length;
832 1
        if (iop < nop-1) {
833 1
            subscripts++;
834
        }
835
    }
836

837
    /*
838
     * Find the number of broadcast dimensions, which is the maximum
839
     * number of labels == 0 in an op_labels array.
840
     */
841
    ndim_broadcast = 0;
842 1
    for (iop = 0; iop < nop; ++iop) {
843 1
        npy_intp count_zeros = 0;
844
        int ndim;
845 1
        char *labels = op_labels[iop];
846

847 1
        ndim = PyArray_NDIM(op_in[iop]);
848 1
        for (idim = 0; idim < ndim; ++idim) {
849 1
            if (labels[idim] == 0) {
850 1
                ++count_zeros;
851
            }
852
        }
853

854 1
        if (count_zeros > ndim_broadcast) {
855 1
            ndim_broadcast = count_zeros;
856
        }
857
    }
858

859
    /*
860
     * If there is no output signature, fill output_labels and ndim_output
861
     * using each label that appeared once, in alphabetical order.
862
     */
863 1
    if (subscripts[0] == '\0') {
864
        /* If no output was specified, always broadcast left, as usual. */
865 1
        for (ndim_output = 0; ndim_output < ndim_broadcast; ++ndim_output) {
866 1
            output_labels[ndim_output] = 0;
867
        }
868 1
        for (label = min_label; label <= max_label; ++label) {
869 1
            if (label_counts[label] == 1) {
870 1
                if (ndim_output < NPY_MAXDIMS) {
871 1
                    output_labels[ndim_output++] = label;
872
                }
873
                else {
874 0
                    PyErr_SetString(PyExc_ValueError,
875
                                "einstein sum subscript string has too many "
876
                                "distinct labels");
877 0
                    return NULL;
878
                }
879
            }
880
        }
881
    }
882
    else {
883 1
        if (subscripts[0] != '-' || subscripts[1] != '>') {
884 0
            PyErr_SetString(PyExc_ValueError,
885
                        "einstein sum subscript string does not "
886
                        "contain proper '->' output specified");
887 0
            return NULL;
888
        }
889 1
        subscripts += 2;
890

891
        /* Parse the output subscript string. */
892 1
        ndim_output = parse_output_subscripts(subscripts, strlen(subscripts),
893
                                        ndim_broadcast, label_counts,
894
                                        output_labels);
895 1
        if (ndim_output < 0) {
896
            return NULL;
897
        }
898
    }
899

900 1
    if (out != NULL && PyArray_NDIM(out) != ndim_output) {
901 0
        PyErr_Format(PyExc_ValueError,
902
                "out parameter does not have the correct number of "
903
                "dimensions, has %d but should have %d",
904
                (int)PyArray_NDIM(out), (int)ndim_output);
905 0
        return NULL;
906
    }
907

908
    /*
909
     * If there's just one operand and no output parameter,
910
     * first try remapping the axes to the output to return
911
     * a view instead of a copy.
912
     */
913 1
    if (nop == 1 && out == NULL) {
914 1
        ret = NULL;
915

916 1
        if (get_single_op_view(op_in[0], op_labels[0], ndim_output,
917
                               output_labels, &ret) < 0) {
918
            return NULL;
919
        }
920

921 1
        if (ret != NULL) {
922
            return ret;
923
        }
924
    }
925

926
    /* Set all the op references to NULL */
927 1
    for (iop = 0; iop < nop; ++iop) {
928 1
        op[iop] = NULL;
929
    }
930

931
    /*
932
     * Process all the input ops, combining dimensions into their
933
     * diagonal where specified.
934
     */
935 1
    for (iop = 0; iop < nop; ++iop) {
936 1
        char *labels = op_labels[iop];
937

938 1
        op[iop] = get_combined_dims_view(op_in[iop], iop, labels);
939 1
        if (op[iop] == NULL) {
940
            goto fail;
941
        }
942
    }
943

944
    /* Set the output op */
945 1
    op[nop] = out;
946

947
    /*
948
     * Set up the labels for the iterator (output + combined labels).
949
     * Can just share the output_labels memory, because iter_labels
950
     * is output_labels with some more labels appended.
951
     */
952 1
    iter_labels = output_labels;
953 1
    ndim_iter = ndim_output;
954 1
    for (label = min_label; label <= max_label; ++label) {
955 1
        if (label_counts[label] > 0 &&
956 1
                memchr(output_labels, label, ndim_output) == NULL) {
957 1
            if (ndim_iter >= NPY_MAXDIMS) {
958 0
                PyErr_SetString(PyExc_ValueError,
959
                            "too many subscripts in einsum");
960 0
                goto fail;
961
            }
962 1
            iter_labels[ndim_iter++] = label;
963
        }
964
    }
965

966
    /* Set up the op_axes for the iterator */
967 1
    for (iop = 0; iop < nop; ++iop) {
968 1
        op_axes[iop] = op_axes_arrays[iop];
969

970 1
        if (prepare_op_axes(PyArray_NDIM(op[iop]), iop, op_labels[iop],
971
                    op_axes[iop], ndim_iter, iter_labels) < 0) {
972
            goto fail;
973
        }
974
    }
975

976
    /* Set up the op_dtypes if dtype was provided */
977 1
    if (dtype == NULL) {
978
        op_dtypes = NULL;
979
    }
980
    else {
981
        op_dtypes = op_dtypes_array;
982 1
        for (iop = 0; iop <= nop; ++iop) {
983 1
            op_dtypes[iop] = dtype;
984
        }
985
    }
986

987
    /* Set the op_axes for the output */
988 1
    op_axes[nop] = op_axes_arrays[nop];
989 1
    for (idim = 0; idim < ndim_output; ++idim) {
990 1
        op_axes[nop][idim] = idim;
991
    }
992 1
    for (idim = ndim_output; idim < ndim_iter; ++idim) {
993 1
        op_axes[nop][idim] = NPY_ITER_REDUCTION_AXIS(-1);
994
    }
995

996
    /* Set the iterator per-op flags */
997

998 1
    for (iop = 0; iop < nop; ++iop) {
999 1
        op_flags[iop] = NPY_ITER_READONLY|
1000
                        NPY_ITER_NBO|
1001
                        NPY_ITER_ALIGNED;
1002
    }
1003 1
    op_flags[nop] = NPY_ITER_READWRITE|
1004
                    NPY_ITER_NBO|
1005
                    NPY_ITER_ALIGNED|
1006
                    NPY_ITER_ALLOCATE;
1007 1
    iter_flags = NPY_ITER_EXTERNAL_LOOP|
1008
            NPY_ITER_BUFFERED|
1009
            NPY_ITER_DELAY_BUFALLOC|
1010
            NPY_ITER_GROWINNER|
1011
            NPY_ITER_REFS_OK|
1012
            NPY_ITER_ZEROSIZE_OK;
1013 1
    if (out != NULL) {
1014 1
        iter_flags |= NPY_ITER_COPY_IF_OVERLAP;
1015
    }
1016 1
    if (dtype == NULL) {
1017 1
        iter_flags |= NPY_ITER_COMMON_DTYPE;
1018
    }
1019

1020
    /* Allocate the iterator */
1021 1
    iter = NpyIter_AdvancedNew(nop+1, op, iter_flags, order, casting, op_flags,
1022
                               op_dtypes, ndim_iter, op_axes, NULL, 0);
1023

1024 1
    if (iter == NULL) {
1025
        goto fail;
1026
    }
1027

1028
    /* Initialize the output to all zeros */
1029 1
    ret = NpyIter_GetOperandArray(iter)[nop];
1030 1
    if (PyArray_AssignZero(ret, NULL) < 0) {
1031
        goto fail;
1032
    }
1033

1034
    /***************************/
1035
    /*
1036
     * Acceleration for some specific loop structures. Note
1037
     * that with axis coalescing, inputs with more dimensions can
1038
     * be reduced to fit into these patterns.
1039
     */
1040 1
    if (!NpyIter_RequiresBuffering(iter)) {
1041 1
        int ndim = NpyIter_GetNDim(iter);
1042 1
        switch (nop) {
1043 1
            case 1:
1044 1
                if (ndim == 2) {
1045 1
                    if (unbuffered_loop_nop1_ndim2(iter) < 0) {
1046
                        goto fail;
1047
                    }
1048
                    goto finish;
1049
                }
1050 1
                else if (ndim == 3) {
1051 0
                    if (unbuffered_loop_nop1_ndim3(iter) < 0) {
1052
                        goto fail;
1053
                    }
1054
                    goto finish;
1055
                }
1056
                break;
1057 1
            case 2:
1058 1
                if (ndim == 2) {
1059 1
                    if (unbuffered_loop_nop2_ndim2(iter) < 0) {
1060
                        goto fail;
1061
                    }
1062
                    goto finish;
1063
                }
1064 1
                else if (ndim == 3) {
1065 1
                    if (unbuffered_loop_nop2_ndim3(iter) < 0) {
1066
                        goto fail;
1067
                    }
1068
                    goto finish;
1069
                }
1070
                break;
1071
        }
1072
    }
1073
    /***************************/
1074

1075 1
    if (NpyIter_Reset(iter, NULL) != NPY_SUCCEED) {
1076
        goto fail;
1077
    }
1078

1079
    /*
1080
     * Get an inner loop function, specializing it based on
1081
     * the strides that are fixed for the whole loop.
1082
     */
1083 1
    NpyIter_GetInnerFixedStrideArray(iter, fixed_strides);
1084 1
    sop = get_sum_of_products_function(nop,
1085 1
                        NpyIter_GetDescrArray(iter)[0]->type_num,
1086 1
                        NpyIter_GetDescrArray(iter)[0]->elsize,
1087
                        fixed_strides);
1088

1089
#if NPY_EINSUM_DBG_TRACING
1090
    NpyIter_DebugPrint(iter);
1091
#endif
1092

1093
    /* Finally, the main loop */
1094 1
    if (sop == NULL) {
1095 0
        PyErr_SetString(PyExc_TypeError,
1096
                    "invalid data type for einsum");
1097
    }
1098 1
    else if (NpyIter_GetIterSize(iter) != 0) {
1099
        NpyIter_IterNextFunc *iternext;
1100
        char **dataptr;
1101
        npy_intp *stride;
1102
        npy_intp *countptr;
1103 1
        NPY_BEGIN_THREADS_DEF;
1104

1105 1
        iternext = NpyIter_GetIterNext(iter, NULL);
1106 1
        if (iternext == NULL) {
1107 0
            NpyIter_Deallocate(iter);
1108 0
            goto fail;
1109
        }
1110 1
        dataptr = NpyIter_GetDataPtrArray(iter);
1111 1
        stride = NpyIter_GetInnerStrideArray(iter);
1112 1
        countptr = NpyIter_GetInnerLoopSizePtr(iter);
1113

1114 1
        NPY_BEGIN_THREADS_NDITER(iter);
1115
        NPY_EINSUM_DBG_PRINT("Einsum loop\n");
1116
        do {
1117 1
            sop(nop, dataptr, stride, *countptr);
1118 1
        } while(iternext(iter));
1119 1
        NPY_END_THREADS;
1120

1121
        /* If the API was needed, it may have thrown an error */
1122 1
        if (NpyIter_IterationNeedsAPI(iter) && PyErr_Occurred()) {
1123
            goto fail;
1124
        }
1125
    }
1126

1127 1
finish:
1128 1
    if (out != NULL) {
1129 1
        ret = out;
1130
    }
1131 1
    Py_INCREF(ret);
1132

1133 1
    NpyIter_Deallocate(iter);
1134 1
    for (iop = 0; iop < nop; ++iop) {
1135 1
        Py_DECREF(op[iop]);
1136
    }
1137

1138 1
    return ret;
1139

1140 1
fail:
1141 1
    for (iop = 0; iop < nop; ++iop) {
1142 1
        Py_XDECREF(op[iop]);
1143
    }
1144

1145
    return NULL;
1146
}

Read our documentation on viewing source code .

Loading