1
/* Fixed size rational numbers exposed to Python */
2

3
#define NPY_NO_DEPRECATED_API NPY_API_VERSION
4

5
#include <Python.h>
6
#include <structmember.h>
7
#include <numpy/arrayobject.h>
8
#include <numpy/ufuncobject.h>
9
#include <numpy/npy_3kcompat.h>
10
#include <math.h>
11

12
#include "common.h"  /* for error_converting */
13

14

15
/* Relevant arithmetic exceptions */
16

17
/* Uncomment the following line to make the error helpers below acquire the GIL themselves (works around a bug in numpy) */
18
/* #define ACQUIRE_GIL */
19

20
static void
21 1
set_overflow(void) {
22
#ifdef ACQUIRE_GIL
23
    /* Need to grab the GIL to dodge a bug in numpy */
24
    PyGILState_STATE state = PyGILState_Ensure();
25
#endif
26 1
    if (!PyErr_Occurred()) {
27 1
        PyErr_SetString(PyExc_OverflowError,
28
                "overflow in rational arithmetic");
29
    }
30
#ifdef ACQUIRE_GIL
31
    PyGILState_Release(state);
32
#endif
33 1
}
34

35
static void
36 0
set_zero_divide(void) {
37
#ifdef ACQUIRE_GIL
38
    /* Need to grab the GIL to dodge a bug in numpy */
39
    PyGILState_STATE state = PyGILState_Ensure();
40
#endif
41 0
    if (!PyErr_Occurred()) {
42 0
        PyErr_SetString(PyExc_ZeroDivisionError,
43
                "zero divide in rational arithmetic");
44
    }
45
#ifdef ACQUIRE_GIL
46
    PyGILState_Release(state);
47
#endif
48 0
}
49

50
/* Integer arithmetic utilities */
51

52
static NPY_INLINE npy_int32
53
safe_neg(npy_int32 x) {
54 0
    if (x==(npy_int32)1<<31) {
55 0
        set_overflow();
56
    }
57 0
    return -x;
58
}
59

60
static NPY_INLINE npy_int32
61
safe_abs32(npy_int32 x) {
62
    npy_int32 nx;
63 0
    if (x>=0) {
64
        return x;
65
    }
66 0
    nx = -x;
67 0
    if (nx<0) {
68 0
        set_overflow();
69
    }
70
    return nx;
71
}
72

73
static NPY_INLINE npy_int64
74
safe_abs64(npy_int64 x) {
75
    npy_int64 nx;
76 1
    if (x>=0) {
77
        return x;
78
    }
79 0
    nx = -x;
80 0
    if (nx<0) {
81 0
        set_overflow();
82
    }
83
    return nx;
84
}
85

86
static NPY_INLINE npy_int64
87 1
gcd(npy_int64 x, npy_int64 y) {
88 1
    x = safe_abs64(x);
89 1
    y = safe_abs64(y);
90 1
    if (x < y) {
91 1
        npy_int64 t = x;
92 1
        x = y;
93 1
        y = t;
94
    }
95 1
    while (y) {
96
        npy_int64 t;
97 1
        x = x%y;
98 1
        t = x;
99 1
        x = y;
100 1
        y = t;
101
    }
102 1
    return x;
103
}
104

105
static NPY_INLINE npy_int64
106 0
lcm(npy_int64 x, npy_int64 y) {
107
    npy_int64 lcm;
108 0
    if (!x || !y) {
109
        return 0;
110
    }
111 0
    x /= gcd(x,y);
112 0
    lcm = x*y;
113 0
    if (lcm/y!=x) {
114 0
        set_overflow();
115
    }
116
    return safe_abs64(lcm);
117
}
118

119
/* Fixed precision rational numbers */
120

121
typedef struct {
122
    /* numerator */
123
    npy_int32 n;
124
    /*
125
     * denominator minus one: numpy.zeros() uses memset(0) for non-object
126
     * types, so need to ensure that rational(0) has all zero bytes
127
     */
128
    npy_int32 dmm;
129
} rational;
130

131
static NPY_INLINE rational
132
make_rational_int(npy_int64 n) {
133 1
    rational r = {(npy_int32)n,0};
134 1
    if (r.n != n) {
135 0
        set_overflow();
136
    }
137
    return r;
138
}
139

140
static rational
141 1
make_rational_slow(npy_int64 n_, npy_int64 d_) {
142 1
    rational r = {0};
143 1
    if (!d_) {
144 0
        set_zero_divide();
145
    }
146
    else {
147 1
        npy_int64 g = gcd(n_,d_);
148
        npy_int32 d;
149 1
        n_ /= g;
150 1
        d_ /= g;
151 1
        r.n = (npy_int32)n_;
152 1
        d = (npy_int32)d_;
153 1
        if (r.n!=n_ || d!=d_) {
154 0
            set_overflow();
155
        }
156
        else {
157 1
            if (d <= 0) {
158 0
                d = -d;
159 0
                r.n = safe_neg(r.n);
160
            }
161 1
            r.dmm = d-1;
162
        }
163
    }
164 1
    return r;
165
}
166

167
static NPY_INLINE npy_int32
168
d(rational r) {
169 1
    return r.dmm+1;
170
}
171

172
/* Assumes d_ > 0 */
173
static rational
174 1
make_rational_fast(npy_int64 n_, npy_int64 d_) {
175 1
    npy_int64 g = gcd(n_,d_);
176
    rational r;
177 1
    n_ /= g;
178 1
    d_ /= g;
179 1
    r.n = (npy_int32)n_;
180 1
    r.dmm = (npy_int32)(d_-1);
181 1
    if (r.n!=n_ || r.dmm+1!=d_) {
182 0
        set_overflow();
183
    }
184 1
    return r;
185
}
186

187
static NPY_INLINE rational
188
rational_negative(rational r) {
189
    rational x;
190 0
    x.n = safe_neg(r.n);
191 0
    x.dmm = r.dmm;
192
    return x;
193
}
194

195
static NPY_INLINE rational
196
rational_add(rational x, rational y) {
197
    /*
198
     * Note that the numerator computation can never overflow int128_t,
199
     * since each term is strictly under 2**128/4 (since d > 0).
200
     */
201 1
    return make_rational_fast((npy_int64)x.n*d(y)+(npy_int64)d(x)*y.n,
202 1
        (npy_int64)d(x)*d(y));
203
}
204

205
static NPY_INLINE rational
206
rational_subtract(rational x, rational y) {
207
    /* We're safe from overflow as with + */
208 0
    return make_rational_fast((npy_int64)x.n*d(y)-(npy_int64)d(x)*y.n,
209 0
        (npy_int64)d(x)*d(y));
210
}
211

212
static NPY_INLINE rational
213
rational_multiply(rational x, rational y) {
214
    /* We're safe from overflow as with + */
215 0
    return make_rational_fast((npy_int64)x.n*y.n,(npy_int64)d(x)*d(y));
216
}
217

218
static NPY_INLINE rational
219
rational_divide(rational x, rational y) {
220 0
    return make_rational_slow((npy_int64)x.n*d(y),(npy_int64)d(x)*y.n);
221
}
222

223
static NPY_INLINE npy_int64
224
rational_floor(rational x) {
225
    /* Always round down */
226 0
    if (x.n>=0) {
227 0
        return x.n/d(x);
228
    }
229
    /*
230
     * This can be done without casting up to 64 bits, but it requires
231
     * working out all the sign cases
232
     */
233 0
    return -((-(npy_int64)x.n+d(x)-1)/d(x));
234
}
235

236
static NPY_INLINE npy_int64
237 0
rational_ceil(rational x) {
238 0
    return -rational_floor(rational_negative(x));
239
}
240

241
static NPY_INLINE rational
242 0
rational_remainder(rational x, rational y) {
243 0
    return rational_subtract(x, rational_multiply(y,make_rational_int(
244
                    rational_floor(rational_divide(x,y)))));
245
}
246

247
static NPY_INLINE rational
248
rational_abs(rational x) {
249
    rational y;
250 0
    y.n = safe_abs32(x.n);
251 0
    y.dmm = x.dmm;
252
    return y;
253
}
254

255
static NPY_INLINE npy_int64
256
rational_rint(rational x) {
257
    /*
258
     * Round towards nearest integer, moving exact half integers towards
259
     * zero
260
     */
261 0
    npy_int32 d_ = d(x);
262 0
    return (2*(npy_int64)x.n+(x.n<0?-d_:d_))/(2*(npy_int64)d_);
263
}
264

265
static NPY_INLINE int
266
rational_sign(rational x) {
267 0
    return x.n<0?-1:x.n==0?0:1;
268
}
269

270
static NPY_INLINE rational
271 0
rational_inverse(rational x) {
272 0
    rational y = {0};
273 0
    if (!x.n) {
274 0
        set_zero_divide();
275
    }
276
    else {
277
        npy_int32 d_;
278 0
        y.n = d(x);
279 0
        d_ = x.n;
280 0
        if (d_ <= 0) {
281 0
            d_ = safe_neg(d_);
282 0
            y.n = -y.n;
283
        }
284 0
        y.dmm = d_-1;
285
    }
286 0
    return y;
287
}
288

289
static NPY_INLINE int
290
rational_eq(rational x, rational y) {
291
    /*
292
     * Since we enforce d > 0, and store fractions in reduced form,
293
     * equality is easy.
294
     */
295 1
    return x.n==y.n && x.dmm==y.dmm;
296
}
297

298
static NPY_INLINE int
299
rational_ne(rational x, rational y) {
300 0
    return !rational_eq(x,y);
301
}
302

303
static NPY_INLINE int
304
rational_lt(rational x, rational y) {
305 0
    return (npy_int64)x.n*d(y) < (npy_int64)y.n*d(x);
306
}
307

308
static NPY_INLINE int
309
rational_gt(rational x, rational y) {
310
    return rational_lt(y,x);
311
}
312

313
static NPY_INLINE int
314
rational_le(rational x, rational y) {
315 0
    return !rational_lt(y,x);
316
}
317

318
static NPY_INLINE int
319
rational_ge(rational x, rational y) {
320 0
    return !rational_lt(x,y);
321
}
322

323
static NPY_INLINE npy_int32
324
rational_int(rational x) {
325 1
    return x.n/d(x);
326
}
327

328
static NPY_INLINE double
329
rational_double(rational x) {
330 1
    return (double)x.n/d(x);
331
}
332

333
static NPY_INLINE int
334
rational_nonzero(rational x) {
335 0
    return x.n!=0;
336
}
337

338
static int
339 0
scan_rational(const char** s, rational* x) {
340
    long n,d;
341
    int offset;
342
    const char* ss;
343 0
    if (sscanf(*s,"%ld%n",&n,&offset)<=0) {
344
        return 0;
345
    }
346 0
    ss = *s+offset;
347 0
    if (*ss!='/') {
348 0
        *s = ss;
349 0
        *x = make_rational_int(n);
350 0
        return 1;
351
    }
352 0
    ss++;
353 0
    if (sscanf(ss,"%ld%n",&d,&offset)<=0 || d<=0) {
354
        return 0;
355
    }
356 0
    *s = ss+offset;
357 0
    *x = make_rational_slow(n,d);
358 0
    return 1;
359
}
360

361
/* Expose rational to Python as a numpy scalar */
362

363
typedef struct {
364
    PyObject_HEAD
365
    rational r;
366
} PyRational;
367

368
static PyTypeObject PyRational_Type;
369

370
static NPY_INLINE int
371
PyRational_Check(PyObject* object) {
372 1
    return PyObject_IsInstance(object,(PyObject*)&PyRational_Type);
373
}
374

375
static PyObject*
376
PyRational_FromRational(rational x) {
377 1
    PyRational* p = (PyRational*)PyRational_Type.tp_alloc(&PyRational_Type,0);
378 1
    if (p) {
379 1
        p->r = x;
380
    }
381
    return (PyObject*)p;
382
}
383

384
static PyObject*
385 1
pyrational_new(PyTypeObject* type, PyObject* args, PyObject* kwds) {
386
    Py_ssize_t size;
387
    PyObject* x[2];
388 1
    long n[2]={0,1};
389
    int i;
390
    rational r;
391 1
    if (kwds && PyDict_Size(kwds)) {
392 0
        PyErr_SetString(PyExc_TypeError,
393
                "constructor takes no keyword arguments");
394 0
        return 0;
395
    }
396 1
    size = PyTuple_GET_SIZE(args);
397 1
    if (size > 2) {
398 0
        PyErr_SetString(PyExc_TypeError,
399
                "expected rational or numerator and optional denominator");
400 0
        return 0;
401
    }
402

403 1
    if (size == 1) {
404 1
        x[0] = PyTuple_GET_ITEM(args, 0);
405 1
        if (PyRational_Check(x[0])) {
406 0
            Py_INCREF(x[0]);
407 0
            return x[0];
408
        }
409
        // TODO: allow construction from unicode strings
410 1
        else if (PyBytes_Check(x[0])) {
411 0
            const char* s = PyBytes_AS_STRING(x[0]);
412
            rational x;
413 0
            if (scan_rational(&s,&x)) {
414
                const char* p;
415 0
                for (p = s; *p; p++) {
416 0
                    if (!isspace(*p)) {
417
                        goto bad;
418
                    }
419
                }
420
                return PyRational_FromRational(x);
421
            }
422 0
            bad:
423 0
            PyErr_Format(PyExc_ValueError,
424
                    "invalid rational literal '%s'",s);
425 0
            return 0;
426
        }
427
    }
428

429 1
    for (i=0; i<size; i++) {
430
        PyObject* y;
431
        int eq;
432 1
        x[i] = PyTuple_GET_ITEM(args, i);
433 1
        n[i] = PyLong_AsLong(x[i]);
434 1
        if (error_converting(n[i])) {
435 0
            if (PyErr_ExceptionMatches(PyExc_TypeError)) {
436 0
                PyErr_Format(PyExc_TypeError,
437
                        "expected integer %s, got %s",
438
                        (i ? "denominator" : "numerator"),
439 0
                        x[i]->ob_type->tp_name);
440
            }
441
            return 0;
442
        }
443
        /* Check that we had an exact integer */
444 1
        y = PyLong_FromLong(n[i]);
445 1
        if (!y) {
446
            return 0;
447
        }
448 1
        eq = PyObject_RichCompareBool(x[i],y,Py_EQ);
449 1
        Py_DECREF(y);
450 1
        if (eq<0) {
451
            return 0;
452
        }
453 1
        if (!eq) {
454 0
            PyErr_Format(PyExc_TypeError,
455
                    "expected integer %s, got %s",
456
                    (i ? "denominator" : "numerator"),
457 0
                    x[i]->ob_type->tp_name);
458 0
            return 0;
459
        }
460
    }
461 1
    r = make_rational_slow(n[0],n[1]);
462 1
    if (PyErr_Occurred()) {
463
        return 0;
464
    }
465
    return PyRational_FromRational(r);
466
}
467

468
/*
 * AS_RATIONAL(dst, object): convert object to a rational stored in dst.
 * Rationals are copied.  Other objects must convert exactly to a C long.
 * Can return from the *enclosing* function:
 *   - Py_NotImplemented (new reference) when object is not an integer
 *     (TypeError from PyLong_AsLong, or a value that is not exactly equal
 *     to its C long conversion);
 *   - 0 with an exception set when the conversion raises anything else,
 *     e.g. OverflowError for too long ints, or the comparison fails.
 * An integer that fits in a C long but not in 32 bits only makes
 * make_rational_int flag OverflowError; callers must check for it.
 */
#define AS_RATIONAL(dst,object) \
    do { \
        (dst).n = 0; \
        if (PyRational_Check(object)) { \
            (dst) = ((PyRational*)(object))->r; \
        } \
        else { \
            long as_long_ = PyLong_AsLong(object); \
            PyObject* round_trip_; \
            int same_; \
            if (error_converting(as_long_)) { \
                if (!PyErr_ExceptionMatches(PyExc_TypeError)) { \
                    return 0; \
                } \
                PyErr_Clear(); \
                Py_INCREF(Py_NotImplemented); \
                return Py_NotImplemented; \
            } \
            round_trip_ = PyLong_FromLong(as_long_); \
            if (round_trip_ == NULL) { \
                return 0; \
            } \
            same_ = PyObject_RichCompareBool((object), round_trip_, Py_EQ); \
            Py_DECREF(round_trip_); \
            if (same_ < 0) { \
                return 0; \
            } \
            if (!same_) { \
                Py_INCREF(Py_NotImplemented); \
                return Py_NotImplemented; \
            } \
            (dst) = make_rational_int(as_long_); \
        } \
    } while (0)
506

507
/*
 * tp_richcompare.  Both operands are converted with AS_RATIONAL, which
 * makes this function return NotImplemented for unsupported operand types.
 */
static PyObject*
pyrational_richcompare(PyObject* a, PyObject* b, int op) {
    rational x, y;
    int result = 0;
    AS_RATIONAL(x,a);
    AS_RATIONAL(y,b);
    /*
     * An integer operand that fits in a C long but not in 32 bits makes
     * make_rational_int flag OverflowError without returning.  Returning a
     * bool with that exception still pending would surface as SystemError.
     */
    if (PyErr_Occurred()) {
        return 0;
    }
    switch (op) {
        case Py_LT: result = rational_lt(x,y); break;
        case Py_LE: result = rational_le(x,y); break;
        case Py_EQ: result = rational_eq(x,y); break;
        case Py_NE: result = rational_ne(x,y); break;
        case Py_GT: result = rational_gt(x,y); break;
        case Py_GE: result = rational_ge(x,y); break;
        default: break;  /* unknown op: result stays False */
    }
    return PyBool_FromLong(result);
}
525

526
/* repr(): "rational(n)" for integers, otherwise "rational(n,d)". */
static PyObject*
pyrational_repr(PyObject* self) {
    rational x = ((PyRational*)self)->r;
    long num = (long)x.n;
    long den = (long)d(x);
    if (den == 1) {
        return PyUnicode_FromFormat("rational(%ld)", num);
    }
    return PyUnicode_FromFormat("rational(%ld,%ld)", num, den);
}
538

539
/* str(): "n" for integers, otherwise "n/d". */
static PyObject*
pyrational_str(PyObject* self) {
    rational x = ((PyRational*)self)->r;
    long num = (long)x.n;
    long den = (long)d(x);
    if (den == 1) {
        return PyUnicode_FromFormat("%ld", num);
    }
    return PyUnicode_FromFormat("%ld/%ld", num, den);
}
551

552
static npy_hash_t
553 0
pyrational_hash(PyObject* self) {
554 0
    rational x = ((PyRational*)self)->r;
555
    /* Use a fairly weak hash as Python expects */
556 0
    long h = 131071*x.n+524287*x.dmm;
557
    /* Never return the special error value -1 */
558 0
    return h==-1?2:h;
559
}
560

561
/*
 * RATIONAL_BINOP_2(name,exp) generates pyrational_<name>, a binary number
 * slot.  Both operands are converted with AS_RATIONAL (which may return
 * NotImplemented) and bound to x and y, then `exp` is evaluated.  Any
 * exception pending afterwards (overflow, zero division) is propagated by
 * returning NULL.
 */
#define RATIONAL_BINOP_2(name,exp) \
    static PyObject* \
    pyrational_##name(PyObject* a, PyObject* b) { \
        rational x, y, result_; \
        AS_RATIONAL(x,a); \
        AS_RATIONAL(y,b); \
        result_ = exp; \
        return PyErr_Occurred() ? NULL : PyRational_FromRational(result_); \
    }
/* Common case: the slot forwards to rational_<name>(x,y). */
#define RATIONAL_BINOP(name) RATIONAL_BINOP_2(name,rational_##name(x,y))
RATIONAL_BINOP(add)
RATIONAL_BINOP(subtract)
RATIONAL_BINOP(multiply)
RATIONAL_BINOP(divide)
RATIONAL_BINOP(remainder)
RATIONAL_BINOP_2(floor_divide,
    make_rational_int(rational_floor(rational_divide(x,y))))
581

582
/*
 * RATIONAL_UNOP(name,type,exp,convert) generates pyrational_<name>, a
 * unary number slot.  The operand is bound to x, `exp` computes a value of
 * `type`, and `convert` turns it into a Python object.  It returns NULL
 * instead if evaluating `exp` left an exception pending.
 */
#define RATIONAL_UNOP(name,type,exp,convert) \
    static PyObject* \
    pyrational_##name(PyObject* self) { \
        rational x = ((PyRational*)self)->r; \
        type value_ = exp; \
        return PyErr_Occurred() ? NULL : convert(value_); \
    }
RATIONAL_UNOP(negative,rational,rational_negative(x),PyRational_FromRational)
RATIONAL_UNOP(absolute,rational,rational_abs(x),PyRational_FromRational)
RATIONAL_UNOP(int,long,rational_int(x),PyLong_FromLong)
RATIONAL_UNOP(float,double,rational_double(x),PyFloat_FromDouble)
596

597
/* Unary plus: the value is unchanged, so hand back self (new reference). */
static PyObject*
pyrational_positive(PyObject* self)
{
    Py_INCREF(self);
    return self;
}
602

603
static int
604 0
pyrational_nonzero(PyObject* self) {
605 0
    rational x = ((PyRational*)self)->r;
606 0
    return rational_nonzero(x);
607
}
608

609
static PyNumberMethods pyrational_as_number = {
610
    pyrational_add,          /* nb_add */
611
    pyrational_subtract,     /* nb_subtract */
612
    pyrational_multiply,     /* nb_multiply */
613
    pyrational_remainder,    /* nb_remainder */
614
    0,                       /* nb_divmod */
615
    0,                       /* nb_power */
616
    pyrational_negative,     /* nb_negative */
617
    pyrational_positive,     /* nb_positive */
618
    pyrational_absolute,     /* nb_absolute */
619
    pyrational_nonzero,      /* nb_nonzero */
620
    0,                       /* nb_invert */
621
    0,                       /* nb_lshift */
622
    0,                       /* nb_rshift */
623
    0,                       /* nb_and */
624
    0,                       /* nb_xor */
625
    0,                       /* nb_or */
626
    pyrational_int,          /* nb_int */
627
    0,                       /* reserved */
628
    pyrational_float,        /* nb_float */
629

630
    0,                       /* nb_inplace_add */
631
    0,                       /* nb_inplace_subtract */
632
    0,                       /* nb_inplace_multiply */
633
    0,                       /* nb_inplace_remainder */
634
    0,                       /* nb_inplace_power */
635
    0,                       /* nb_inplace_lshift */
636
    0,                       /* nb_inplace_rshift */
637
    0,                       /* nb_inplace_and */
638
    0,                       /* nb_inplace_xor */
639
    0,                       /* nb_inplace_or */
640

641
    pyrational_floor_divide, /* nb_floor_divide */
642
    pyrational_divide,       /* nb_true_divide */
643
    0,                       /* nb_inplace_floor_divide */
644
    0,                       /* nb_inplace_true_divide */
645
    0,                       /* nb_index */
646
};
647

648
/* Getter for the read-only .n attribute (numerator). */
static PyObject*
pyrational_n(PyObject* self, void* closure) {
    rational value = ((PyRational*)self)->r;
    return PyLong_FromLong(value.n);
}
652

653
/* Getter for the read-only .d attribute (denominator, always > 0). */
static PyObject*
pyrational_d(PyObject* self, void* closure) {
    rational value = ((PyRational*)self)->r;
    return PyLong_FromLong(d(value));
}
657

658
static PyGetSetDef pyrational_getset[] = {
659
    {(char*)"n",pyrational_n,0,(char*)"numerator",0},
660
    {(char*)"d",pyrational_d,0,(char*)"denominator",0},
661
    {0} /* sentinel */
662
};
663

664
static PyTypeObject PyRational_Type = {
665
    PyVarObject_HEAD_INIT(NULL, 0)
666
    "rational",                               /* tp_name */
667
    sizeof(PyRational),                       /* tp_basicsize */
668
    0,                                        /* tp_itemsize */
669
    0,                                        /* tp_dealloc */
670
    0,                                        /* tp_print */
671
    0,                                        /* tp_getattr */
672
    0,                                        /* tp_setattr */
673
    0,                                        /* tp_reserved */
674
    pyrational_repr,                          /* tp_repr */
675
    &pyrational_as_number,                    /* tp_as_number */
676
    0,                                        /* tp_as_sequence */
677
    0,                                        /* tp_as_mapping */
678
    pyrational_hash,                          /* tp_hash */
679
    0,                                        /* tp_call */
680
    pyrational_str,                           /* tp_str */
681
    0,                                        /* tp_getattro */
682
    0,                                        /* tp_setattro */
683
    0,                                        /* tp_as_buffer */
684
    Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */
685
    "Fixed precision rational numbers",       /* tp_doc */
686
    0,                                        /* tp_traverse */
687
    0,                                        /* tp_clear */
688
    pyrational_richcompare,                   /* tp_richcompare */
689
    0,                                        /* tp_weaklistoffset */
690
    0,                                        /* tp_iter */
691
    0,                                        /* tp_iternext */
692
    0,                                        /* tp_methods */
693
    0,                                        /* tp_members */
694
    pyrational_getset,                        /* tp_getset */
695
    0,                                        /* tp_base */
696
    0,                                        /* tp_dict */
697
    0,                                        /* tp_descr_get */
698
    0,                                        /* tp_descr_set */
699
    0,                                        /* tp_dictoffset */
700
    0,                                        /* tp_init */
701
    0,                                        /* tp_alloc */
702
    pyrational_new,                           /* tp_new */
703
    0,                                        /* tp_free */
704
    0,                                        /* tp_is_gc */
705
    0,                                        /* tp_bases */
706
    0,                                        /* tp_mro */
707
    0,                                        /* tp_cache */
708
    0,                                        /* tp_subclasses */
709
    0,                                        /* tp_weaklist */
710
    0,                                        /* tp_del */
711
    0,                                        /* tp_version_tag */
712
};
713

714
/* NumPy support */
715

716
static PyObject*
717 0
npyrational_getitem(void* data, void* arr) {
718
    rational r;
719 0
    memcpy(&r,data,sizeof(rational));
720 0
    return PyRational_FromRational(r);
721
}
722

723
static int
724 1
npyrational_setitem(PyObject* item, void* data, void* arr) {
725
    rational r;
726 1
    if (PyRational_Check(item)) {
727 1
        r = ((PyRational*)item)->r;
728
    }
729
    else {
730 1
        long long n = PyLong_AsLongLong(item);
731
        PyObject* y;
732
        int eq;
733 1
        if (error_converting(n)) {
734
            return -1;
735
        }
736 1
        y = PyLong_FromLongLong(n);
737 1
        if (!y) {
738
            return -1;
739
        }
740 1
        eq = PyObject_RichCompareBool(item, y, Py_EQ);
741 1
        Py_DECREF(y);
742 1
        if (eq<0) {
743
            return -1;
744
        }
745 1
        if (!eq) {
746 0
            PyErr_Format(PyExc_TypeError,
747 0
                    "expected rational, got %s", item->ob_type->tp_name);
748 0
            return -1;
749
        }
750 1
        r = make_rational_int(n);
751
    }
752 1
    memcpy(data, &r, sizeof(rational));
753 1
    return 0;
754
}
755

756
static NPY_INLINE void
757
byteswap(npy_int32* x) {
758 0
    char* p = (char*)x;
759
    size_t i;
760 0
    for (i = 0; i < sizeof(*x)/2; i++) {
761 0
        size_t j = sizeof(*x)-1-i;
762 0
        char t = p[i];
763 0
        p[i] = p[j];
764 0
        p[j] = t;
765
    }
766
}
767

768
static void
769 1
npyrational_copyswapn(void* dst_, npy_intp dstride, void* src_,
770
        npy_intp sstride, npy_intp n, int swap, void* arr) {
771 1
    char *dst = (char*)dst_, *src = (char*)src_;
772
    npy_intp i;
773 1
    if (!src) {
774
        return;
775
    }
776 1
    if (swap) {
777 0
        for (i = 0; i < n; i++) {
778 0
            rational* r = (rational*)(dst+dstride*i);
779 0
            memcpy(r,src+sstride*i,sizeof(rational));
780 0
            byteswap(&r->n);
781 0
            byteswap(&r->dmm);
782
        }
783
    }
784 1
    else if (dstride == sizeof(rational) && sstride == sizeof(rational)) {
785 1
        memcpy(dst, src, n*sizeof(rational));
786
    }
787
    else {
788 1
        for (i = 0; i < n; i++) {
789 1
            memcpy(dst + dstride*i, src + sstride*i, sizeof(rational));
790
        }
791
    }
792
}
793

794
static void
795 0
npyrational_copyswap(void* dst, void* src, int swap, void* arr) {
796
    rational* r;
797 0
    if (!src) {
798
        return;
799
    }
800 0
    r = (rational*)dst;
801 0
    memcpy(r,src,sizeof(rational));
802 0
    if (swap) {
803 0
        byteswap(&r->n);
804 0
        byteswap(&r->dmm);
805
    }
806
}
807

808
static int
809 0
npyrational_compare(const void* d0, const void* d1, void* arr) {
810 0
    rational x = *(rational*)d0,
811 0
             y = *(rational*)d1;
812 0
    return rational_lt(x,y)?-1:rational_eq(x,y)?0:1;
813
}
814

815
#define FIND_EXTREME(name,op) \
816
    static int \
817
    npyrational_##name(void* data_, npy_intp n, \
818
            npy_intp* max_ind, void* arr) { \
819
        const rational* data; \
820
        npy_intp best_i; \
821
        rational best_r; \
822
        npy_intp i; \
823
        if (!n) { \
824
            return 0; \
825
        } \
826
        data = (rational*)data_; \
827
        best_i = 0; \
828
        best_r = data[0]; \
829
        for (i = 1; i < n; i++) { \
830
            if (rational_##op(data[i],best_r)) { \
831
                best_i = i; \
832
                best_r = data[i]; \
833
            } \
834
        } \
835
        *max_ind = best_i; \
836
        return 0; \
837
    }
838 0
FIND_EXTREME(argmin,lt)
839 0
FIND_EXTREME(argmax,gt)
840

841
static void
842 0
npyrational_dot(void* ip0_, npy_intp is0, void* ip1_, npy_intp is1,
843
        void* op, npy_intp n, void* arr) {
844 0
    rational r = {0};
845 0
    const char *ip0 = (char*)ip0_, *ip1 = (char*)ip1_;
846
    npy_intp i;
847 0
    for (i = 0; i < n; i++) {
848 0
        r = rational_add(r,rational_multiply(*(rational*)ip0,*(rational*)ip1));
849 0
        ip0 += is0;
850 0
        ip1 += is1;
851
    }
852 0
    *(rational*)op = r;
853 0
}
854

855
static npy_bool
856 0
npyrational_nonzero(void* data, void* arr) {
857
    rational r;
858 0
    memcpy(&r,data,sizeof(r));
859 0
    return rational_nonzero(r)?NPY_TRUE:NPY_FALSE;
860
}
861

862
static int
863 0
npyrational_fill(void* data_, npy_intp length, void* arr) {
864 0
    rational* data = (rational*)data_;
865 0
    rational delta = rational_subtract(data[1],data[0]);
866 0
    rational r = data[1];
867
    npy_intp i;
868 0
    for (i = 2; i < length; i++) {
869 0
        r = rational_add(r,delta);
870 0
        data[i] = r;
871
    }
872 0
    return 0;
873
}
874

875
static int
876 0
npyrational_fillwithscalar(void* buffer_, npy_intp length,
877
        void* value, void* arr) {
878 0
    rational r = *(rational*)value;
879 0
    rational* buffer = (rational*)buffer_;
880
    npy_intp i;
881 0
    for (i = 0; i < length; i++) {
882 0
        buffer[i] = r;
883
    }
884 0
    return 0;
885
}
886

887
static PyArray_ArrFuncs npyrational_arrfuncs;
888

889
typedef struct { char c; rational r; } align_test;
890

891
PyArray_Descr npyrational_descr = {
892
    PyObject_HEAD_INIT(0)
893
    &PyRational_Type,       /* typeobj */
894
    'V',                    /* kind */
895
    'r',                    /* type */
896
    '=',                    /* byteorder */
897
    /*
898
     * For now, we need NPY_NEEDS_PYAPI in order to make numpy detect our
899
     * exceptions.  This isn't technically necessary,
900
     * since we're careful about thread safety, and hopefully future
901
     * versions of numpy will recognize that.
902
     */
903
    NPY_NEEDS_PYAPI | NPY_USE_GETITEM | NPY_USE_SETITEM, /* hasobject */
904
    0,                      /* type_num */
905
    sizeof(rational),       /* elsize */
906
    offsetof(align_test,r), /* alignment */
907
    0,                      /* subarray */
908
    0,                      /* fields */
909
    0,                      /* names */
910
    &npyrational_arrfuncs,  /* f */
911
};
912

913
#define DEFINE_CAST(From,To,statement) \
914
    static void \
915
    npycast_##From##_##To(void* from_, void* to_, npy_intp n, \
916
                          void* fromarr, void* toarr) { \
917
        const From* from = (From*)from_; \
918
        To* to = (To*)to_; \
919
        npy_intp i; \
920
        for (i = 0; i < n; i++) { \
921
            From x = from[i]; \
922
            statement \
923
            to[i] = y; \
924
        } \
925
    }
926
#define DEFINE_INT_CAST(bits) \
927
    DEFINE_CAST(npy_int##bits,rational,rational y = make_rational_int(x);) \
928
    DEFINE_CAST(rational,npy_int##bits,npy_int32 z = rational_int(x); \
929
                npy_int##bits y = z; if (y != z) set_overflow();)
930 1
DEFINE_INT_CAST(8)
931 1
DEFINE_INT_CAST(16)
932 1
DEFINE_INT_CAST(32)
933 1
DEFINE_INT_CAST(64)
934 1
DEFINE_CAST(rational,float,double y = rational_double(x);)
935 1
DEFINE_CAST(rational,double,double y = rational_double(x);)
936 0
DEFINE_CAST(npy_bool,rational,rational y = make_rational_int(x);)
937 0
DEFINE_CAST(rational,npy_bool,npy_bool y = rational_nonzero(x);)
938

939
#define BINARY_UFUNC(name,intype0,intype1,outtype,exp) \
940
    void name(char** args, npy_intp const *dimensions, \
941
              npy_intp const *steps, void* data) { \
942
        npy_intp is0 = steps[0], is1 = steps[1], \
943
            os = steps[2], n = *dimensions; \
944
        char *i0 = args[0], *i1 = args[1], *o = args[2]; \
945
        int k; \
946
        for (k = 0; k < n; k++) { \
947
            intype0 x = *(intype0*)i0; \
948
            intype1 y = *(intype1*)i1; \
949
            *(outtype*)o = exp; \
950
            i0 += is0; i1 += is1; o += os; \
951
        } \
952
    }
953
#define RATIONAL_BINARY_UFUNC(name,type,exp) \
954
    BINARY_UFUNC(rational_ufunc_##name,rational,rational,type,exp)
955 0
RATIONAL_BINARY_UFUNC(add,rational,rational_add(x,y))
956 0
RATIONAL_BINARY_UFUNC(subtract,rational,rational_subtract(x,y))
957 0
RATIONAL_BINARY_UFUNC(multiply,rational,rational_multiply(x,y))
958 0
RATIONAL_BINARY_UFUNC(divide,rational,rational_divide(x,y))
959 0
RATIONAL_BINARY_UFUNC(remainder,rational,rational_remainder(x,y))
960 0
RATIONAL_BINARY_UFUNC(floor_divide,rational,
961
    make_rational_int(rational_floor(rational_divide(x,y))))
962
PyUFuncGenericFunction rational_ufunc_true_divide = rational_ufunc_divide;
963 0
RATIONAL_BINARY_UFUNC(minimum,rational,rational_lt(x,y)?x:y)
964 0
RATIONAL_BINARY_UFUNC(maximum,rational,rational_lt(x,y)?y:x)
965 1
RATIONAL_BINARY_UFUNC(equal,npy_bool,rational_eq(x,y))
966 0
RATIONAL_BINARY_UFUNC(not_equal,npy_bool,rational_ne(x,y))
967 0
RATIONAL_BINARY_UFUNC(less,npy_bool,rational_lt(x,y))
968 0
RATIONAL_BINARY_UFUNC(greater,npy_bool,rational_gt(x,y))
969 0
RATIONAL_BINARY_UFUNC(less_equal,npy_bool,rational_le(x,y))
970 0
RATIONAL_BINARY_UFUNC(greater_equal,npy_bool,rational_ge(x,y))
971

972 0
BINARY_UFUNC(gcd_ufunc,npy_int64,npy_int64,npy_int64,gcd(x,y))
973 0
BINARY_UFUNC(lcm_ufunc,npy_int64,npy_int64,npy_int64,lcm(x,y))
974

975
#define UNARY_UFUNC(name,type,exp) \
976
    void rational_ufunc_##name(char** args, npy_intp const *dimensions, \
977
                               npy_intp const *steps, void* data) { \
978
        npy_intp is = steps[0], os = steps[1], n = *dimensions; \
979
        char *i = args[0], *o = args[1]; \
980
        int k; \
981
        for (k = 0; k < n; k++) { \
982
            rational x = *(rational*)i; \
983
            *(type*)o = exp; \
984
            i += is; o += os; \
985
        } \
986
    }
987 0
UNARY_UFUNC(negative,rational,rational_negative(x))
988 0
UNARY_UFUNC(absolute,rational,rational_abs(x))
989 0
UNARY_UFUNC(floor,rational,make_rational_int(rational_floor(x)))
990 0
UNARY_UFUNC(ceil,rational,make_rational_int(rational_ceil(x)))
991 0
UNARY_UFUNC(trunc,rational,make_rational_int(x.n/d(x)))
992 0
UNARY_UFUNC(square,rational,rational_multiply(x,x))
993 0
UNARY_UFUNC(rint,rational,make_rational_int(rational_rint(x)))
994 0
UNARY_UFUNC(sign,rational,make_rational_int(rational_sign(x)))
995 0
UNARY_UFUNC(reciprocal,rational,rational_inverse(x))
996 0
UNARY_UFUNC(numerator,npy_int64,x.n)
997 0
UNARY_UFUNC(denominator,npy_int64,d(x))
998

999
static NPY_INLINE void
1000 0
rational_matrix_multiply(char **args, npy_intp const *dimensions, npy_intp const *steps)
1001
{
1002
    /* pointers to data for input and output arrays */
1003 0
    char *ip1 = args[0];
1004 0
    char *ip2 = args[1];
1005 0
    char *op = args[2];
1006

1007
    /* lengths of core dimensions */
1008 0
    npy_intp dm = dimensions[0];
1009 0
    npy_intp dn = dimensions[1];
1010 0
    npy_intp dp = dimensions[2];
1011

1012
    /* striding over core dimensions */
1013 0
    npy_intp is1_m = steps[0];
1014 0
    npy_intp is1_n = steps[1];
1015 0
    npy_intp is2_n = steps[2];
1016 0
    npy_intp is2_p = steps[3];
1017 0
    npy_intp os_m = steps[4];
1018 0
    npy_intp os_p = steps[5];
1019

1020
    /* core dimensions counters */
1021
    npy_intp m, p;
1022

1023
    /* calculate dot product for each row/column vector pair */
1024 0
    for (m = 0; m < dm; m++) {
1025 0
        for (p = 0; p < dp; p++) {
1026 0
            npyrational_dot(ip1, is1_n, ip2, is2_n, op, dn, NULL);
1027

1028
            /* advance to next column of 2nd input array and output array */
1029 0
            ip2 += is2_p;
1030 0
            op  +=  os_p;
1031
        }
1032

1033
        /* reset to first column of 2nd input array and output array */
1034 0
        ip2 -= is2_p * p;
1035 0
        op -= os_p * p;
1036

1037
        /* advance to next row of 1st input array and output array */
1038 0
        ip1 += is1_m;
1039 0
        op += os_m;
1040
    }
1041 0
}
1042

1043

1044
static void
1045 0
rational_gufunc_matrix_multiply(char **args, npy_intp const *dimensions,
1046
                                npy_intp const *steps, void *NPY_UNUSED(func))
1047
{
1048
    /* outer dimensions counter */
1049
    npy_intp N_;
1050

1051
    /* length of flattened outer dimensions */
1052 0
    npy_intp dN = dimensions[0];
1053

1054
    /* striding over flattened outer dimensions for input and output arrays */
1055 0
    npy_intp s0 = steps[0];
1056 0
    npy_intp s1 = steps[1];
1057 0
    npy_intp s2 = steps[2];
1058

1059
    /*
1060
     * loop through outer dimensions, performing matrix multiply on
1061
     * core dimensions for each loop
1062
     */
1063 0
    for (N_ = 0; N_ < dN; N_++, args[0] += s0, args[1] += s1, args[2] += s2) {
1064 0
        rational_matrix_multiply(args, dimensions+1, steps+3);
1065
    }
1066 0
}
1067

1068

1069
static void
1070 1
rational_ufunc_test_add(char** args, npy_intp const *dimensions,
1071
                        npy_intp const *steps, void* data) {
1072 1
    npy_intp is0 = steps[0], is1 = steps[1], os = steps[2], n = *dimensions;
1073 1
    char *i0 = args[0], *i1 = args[1], *o = args[2];
1074
    int k;
1075 1
    for (k = 0; k < n; k++) {
1076 1
        npy_int64 x = *(npy_int64*)i0;
1077 1
        npy_int64 y = *(npy_int64*)i1;
1078 1
        *(rational*)o = rational_add(make_rational_fast(x, 1),
1079
                                     make_rational_fast(y, 1));
1080 1
        i0 += is0; i1 += is1; o += os;
1081
    }
1082 1
}
1083

1084

1085
static void
1086 1
rational_ufunc_test_add_rationals(char** args, npy_intp const *dimensions,
1087
                        npy_intp const *steps, void* data) {
1088 1
    npy_intp is0 = steps[0], is1 = steps[1], os = steps[2], n = *dimensions;
1089 1
    char *i0 = args[0], *i1 = args[1], *o = args[2];
1090
    int k;
1091 1
    for (k = 0; k < n; k++) {
1092 1
        rational x = *(rational*)i0;
1093 1
        rational y = *(rational*)i1;
1094 1
        *(rational*)o = rational_add(x, y);
1095 1
        i0 += is0; i1 += is1; o += os;
1096
    }
1097 1
}
1098

1099

1100
/* Method table for the module: it defines no plain functions, so the table
 * holds only the terminating all-zero entry. */
PyMethodDef module_methods[] = {
    {NULL, NULL, 0, NULL}  /* sentinel */
};
1103

1104
static struct PyModuleDef moduledef = {
1105
    PyModuleDef_HEAD_INIT,
1106
    "_rational_tests",
1107
    NULL,
1108
    -1,
1109
    module_methods,
1110
    NULL,
1111
    NULL,
1112
    NULL,
1113
    NULL
1114
};
1115

1116 1
PyMODINIT_FUNC PyInit__rational_tests(void) {
1117 1
    PyObject *m = NULL;
1118
    PyObject* numpy_str;
1119
    PyObject* numpy;
1120
    int npy_rational;
1121

1122 1
    import_array();
1123 1
    if (PyErr_Occurred()) {
1124
        goto fail;
1125
    }
1126 1
    import_umath();
1127 1
    if (PyErr_Occurred()) {
1128
        goto fail;
1129
    }
1130 1
    numpy_str = PyUnicode_FromString("numpy");
1131 1
    if (!numpy_str) {
1132
        goto fail;
1133
    }
1134 1
    numpy = PyImport_Import(numpy_str);
1135 1
    Py_DECREF(numpy_str);
1136 1
    if (!numpy) {
1137
        goto fail;
1138
    }
1139

1140
    /* Can't set this until we import numpy */
1141 1
    PyRational_Type.tp_base = &PyGenericArrType_Type;
1142

1143
    /* Initialize rational type object */
1144 1
    if (PyType_Ready(&PyRational_Type) < 0) {
1145
        goto fail;
1146
    }
1147

1148
    /* Initialize rational descriptor */
1149 1
    PyArray_InitArrFuncs(&npyrational_arrfuncs);
1150 1
    npyrational_arrfuncs.getitem = npyrational_getitem;
1151 1
    npyrational_arrfuncs.setitem = npyrational_setitem;
1152 1
    npyrational_arrfuncs.copyswapn = npyrational_copyswapn;
1153 1
    npyrational_arrfuncs.copyswap = npyrational_copyswap;
1154 1
    npyrational_arrfuncs.compare = npyrational_compare;
1155 1
    npyrational_arrfuncs.argmin = npyrational_argmin;
1156 1
    npyrational_arrfuncs.argmax = npyrational_argmax;
1157 1
    npyrational_arrfuncs.dotfunc = npyrational_dot;
1158 1
    npyrational_arrfuncs.nonzero = npyrational_nonzero;
1159 1
    npyrational_arrfuncs.fill = npyrational_fill;
1160 1
    npyrational_arrfuncs.fillwithscalar = npyrational_fillwithscalar;
1161
    /* Left undefined: scanfunc, fromstr, sort, argsort */
1162 1
    Py_SET_TYPE(&npyrational_descr, &PyArrayDescr_Type);
1163 1
    npy_rational = PyArray_RegisterDataType(&npyrational_descr);
1164 1
    if (npy_rational<0) {
1165
        goto fail;
1166
    }
1167

1168
    /* Support dtype(rational) syntax */
1169 1
    if (PyDict_SetItemString(PyRational_Type.tp_dict, "dtype",
1170
                             (PyObject*)&npyrational_descr) < 0) {
1171
        goto fail;
1172
    }
1173

1174
    /* Register casts to and from rational */
1175
    #define REGISTER_CAST(From,To,from_descr,to_typenum,safe) { \
1176
            PyArray_Descr* from_descr_##From##_##To = (from_descr); \
1177
            if (PyArray_RegisterCastFunc(from_descr_##From##_##To, \
1178
                                         (to_typenum), \
1179
                                         npycast_##From##_##To) < 0) { \
1180
                goto fail; \
1181
            } \
1182
            if (safe && PyArray_RegisterCanCast(from_descr_##From##_##To, \
1183
                                                (to_typenum), \
1184
                                                NPY_NOSCALAR) < 0) { \
1185
                goto fail; \
1186
            } \
1187
        }
1188
    #define REGISTER_INT_CASTS(bits) \
1189
        REGISTER_CAST(npy_int##bits, rational, \
1190
                      PyArray_DescrFromType(NPY_INT##bits), npy_rational, 1) \
1191
        REGISTER_CAST(rational, npy_int##bits, &npyrational_descr, \
1192
                      NPY_INT##bits, 0)
1193 1
    REGISTER_INT_CASTS(8)
1194 1
    REGISTER_INT_CASTS(16)
1195 1
    REGISTER_INT_CASTS(32)
1196 1
    REGISTER_INT_CASTS(64)
1197 1
    REGISTER_CAST(rational,float,&npyrational_descr,NPY_FLOAT,0)
1198 1
    REGISTER_CAST(rational,double,&npyrational_descr,NPY_DOUBLE,1)
1199 1
    REGISTER_CAST(npy_bool,rational, PyArray_DescrFromType(NPY_BOOL),
1200
                  npy_rational,1)
1201 1
    REGISTER_CAST(rational,npy_bool,&npyrational_descr,NPY_BOOL,0)
1202

1203
    /* Register ufuncs */
1204
    #define REGISTER_UFUNC(name,...) { \
1205
        PyUFuncObject* ufunc = \
1206
            (PyUFuncObject*)PyObject_GetAttrString(numpy, #name); \
1207
        int _types[] = __VA_ARGS__; \
1208
        if (!ufunc) { \
1209
            goto fail; \
1210
        } \
1211
        if (sizeof(_types)/sizeof(int)!=ufunc->nargs) { \
1212
            PyErr_Format(PyExc_AssertionError, \
1213
                         "ufunc %s takes %d arguments, our loop takes %lu", \
1214
                         #name, ufunc->nargs, (unsigned long) \
1215
                         (sizeof(_types)/sizeof(int))); \
1216
            Py_DECREF(ufunc); \
1217
            goto fail; \
1218
        } \
1219
        if (PyUFunc_RegisterLoopForType((PyUFuncObject*)ufunc, npy_rational, \
1220
                rational_ufunc_##name, _types, 0) < 0) { \
1221
            Py_DECREF(ufunc); \
1222
            goto fail; \
1223
        } \
1224
        Py_DECREF(ufunc); \
1225
    }
1226
    #define REGISTER_UFUNC_BINARY_RATIONAL(name) \
1227
        REGISTER_UFUNC(name, {npy_rational, npy_rational, npy_rational})
1228
    #define REGISTER_UFUNC_BINARY_COMPARE(name) \
1229
        REGISTER_UFUNC(name, {npy_rational, npy_rational, NPY_BOOL})
1230
    #define REGISTER_UFUNC_UNARY(name) \
1231
        REGISTER_UFUNC(name, {npy_rational, npy_rational})
1232
    /* Binary */
1233 1
    REGISTER_UFUNC_BINARY_RATIONAL(add)
1234 1
    REGISTER_UFUNC_BINARY_RATIONAL(subtract)
1235 1
    REGISTER_UFUNC_BINARY_RATIONAL(multiply)
1236 1
    REGISTER_UFUNC_BINARY_RATIONAL(divide)
1237 1
    REGISTER_UFUNC_BINARY_RATIONAL(remainder)
1238 1
    REGISTER_UFUNC_BINARY_RATIONAL(true_divide)
1239 1
    REGISTER_UFUNC_BINARY_RATIONAL(floor_divide)
1240 1
    REGISTER_UFUNC_BINARY_RATIONAL(minimum)
1241 1
    REGISTER_UFUNC_BINARY_RATIONAL(maximum)
1242
    /* Comparisons */
1243 1
    REGISTER_UFUNC_BINARY_COMPARE(equal)
1244 1
    REGISTER_UFUNC_BINARY_COMPARE(not_equal)
1245 1
    REGISTER_UFUNC_BINARY_COMPARE(less)
1246 1
    REGISTER_UFUNC_BINARY_COMPARE(greater)
1247 1
    REGISTER_UFUNC_BINARY_COMPARE(less_equal)
1248 1
    REGISTER_UFUNC_BINARY_COMPARE(greater_equal)
1249
    /* Unary */
1250 1
    REGISTER_UFUNC_UNARY(negative)
1251 1
    REGISTER_UFUNC_UNARY(absolute)
1252 1
    REGISTER_UFUNC_UNARY(floor)
1253 1
    REGISTER_UFUNC_UNARY(ceil)
1254 1
    REGISTER_UFUNC_UNARY(trunc)
1255 1
    REGISTER_UFUNC_UNARY(rint)
1256 1
    REGISTER_UFUNC_UNARY(square)
1257 1
    REGISTER_UFUNC_UNARY(reciprocal)
1258 1
    REGISTER_UFUNC_UNARY(sign)
1259

1260
    /* Create module */
1261 1
    m = PyModule_Create(&moduledef);
1262

1263 1
    if (!m) {
1264
        goto fail;
1265
    }
1266

1267
    /* Add rational type */
1268 1
    Py_INCREF(&PyRational_Type);
1269 1
    PyModule_AddObject(m,"rational",(PyObject*)&PyRational_Type);
1270

1271
    /* Create matrix multiply generalized ufunc */
1272
    {
1273 1
        int types2[3] = {npy_rational,npy_rational,npy_rational};
1274 1
        PyObject* gufunc = PyUFunc_FromFuncAndDataAndSignature(0,0,0,0,2,1,
1275
            PyUFunc_None,(char*)"matrix_multiply",
1276
            (char*)"return result of multiplying two matrices of rationals",
1277
            0,"(m,n),(n,p)->(m,p)");
1278 1
        if (!gufunc) {
1279
            goto fail;
1280
        }
1281 1
        if (PyUFunc_RegisterLoopForType((PyUFuncObject*)gufunc, npy_rational,
1282
                rational_gufunc_matrix_multiply, types2, 0) < 0) {
1283
            goto fail;
1284
        }
1285 1
        PyModule_AddObject(m,"matrix_multiply",(PyObject*)gufunc);
1286
    }
1287

1288
    /* Create test ufunc with built in input types and rational output type */
1289
    {
1290 1
        int types3[3] = {NPY_INT64,NPY_INT64,npy_rational};
1291

1292 1
        PyObject* ufunc = PyUFunc_FromFuncAndData(0,0,0,0,2,1,
1293
                PyUFunc_None,(char*)"test_add",
1294
                (char*)"add two matrices of int64 and return rational matrix",0);
1295 1
        if (!ufunc) {
1296
            goto fail;
1297
        }
1298 1
        if (PyUFunc_RegisterLoopForType((PyUFuncObject*)ufunc, npy_rational,
1299
                rational_ufunc_test_add, types3, 0) < 0) {
1300
            goto fail;
1301
        }
1302 1
        PyModule_AddObject(m,"test_add",(PyObject*)ufunc);
1303
    }
1304

1305
    /* Create test ufunc with rational types using RegisterLoopForDescr */
1306
    {
1307 1
        PyObject* ufunc = PyUFunc_FromFuncAndData(0,0,0,0,2,1,
1308
                PyUFunc_None,(char*)"test_add_rationals",
1309
                (char*)"add two matrices of rationals and return rational matrix",0);
1310 1
        PyArray_Descr* types[3] = {&npyrational_descr,
1311
                                    &npyrational_descr,
1312
                                    &npyrational_descr};
1313

1314 1
        if (!ufunc) {
1315
            goto fail;
1316
        }
1317 1
        if (PyUFunc_RegisterLoopForDescr((PyUFuncObject*)ufunc, &npyrational_descr,
1318
                rational_ufunc_test_add_rationals, types, 0) < 0) {
1319
            goto fail;
1320
        }
1321 1
        PyModule_AddObject(m,"test_add_rationals",(PyObject*)ufunc);
1322
    }
1323

1324
    /* Create numerator and denominator ufuncs */
1325
    #define NEW_UNARY_UFUNC(name,type,doc) { \
1326
        int types[2] = {npy_rational,type}; \
1327
        PyObject* ufunc = PyUFunc_FromFuncAndData(0,0,0,0,1,1, \
1328
            PyUFunc_None,(char*)#name,(char*)doc,0); \
1329
        if (!ufunc) { \
1330
            goto fail; \
1331
        } \
1332
        if (PyUFunc_RegisterLoopForType((PyUFuncObject*)ufunc, \
1333
                npy_rational,rational_ufunc_##name,types,0)<0) { \
1334
            goto fail; \
1335
        } \
1336
        PyModule_AddObject(m,#name,(PyObject*)ufunc); \
1337
    }
1338 1
    NEW_UNARY_UFUNC(numerator,NPY_INT64,"rational number numerator");
1339 1
    NEW_UNARY_UFUNC(denominator,NPY_INT64,"rational number denominator");
1340

1341
    /* Create gcd and lcm ufuncs */
1342
    #define GCD_LCM_UFUNC(name,type,doc) { \
1343
        static const PyUFuncGenericFunction func[1] = {name##_ufunc}; \
1344
        static const char types[3] = {type,type,type}; \
1345
        static void* data[1] = {0}; \
1346
        PyObject* ufunc = PyUFunc_FromFuncAndData( \
1347
            (PyUFuncGenericFunction*)func, data,(char*)types, \
1348
            1,2,1,PyUFunc_One,(char*)#name,(char*)doc,0); \
1349
        if (!ufunc) { \
1350
            goto fail; \
1351
        } \
1352
        PyModule_AddObject(m,#name,(PyObject*)ufunc); \
1353
    }
1354 1
    GCD_LCM_UFUNC(gcd,NPY_INT64,"greatest common denominator of two integers");
1355 1
    GCD_LCM_UFUNC(lcm,NPY_INT64,"least common multiple of two integers");
1356

1357 1
    return m;
1358

1359 0
fail:
1360 0
    if (!PyErr_Occurred()) {
1361 0
        PyErr_SetString(PyExc_RuntimeError,
1362
                        "cannot load _rational_tests module.");
1363
    }
1364 0
    if (m) {
1365 0
        Py_DECREF(m);
1366
        m = NULL;
1367
    }
1368
    return m;
1369
}

Read our documentation on viewing source code .

Loading