Skip to content

Commit

Permalink
Use simpler version of numpy 2 compat
Browse files Browse the repository at this point in the history
Co-authored-by: Peter Hawkins <[email protected]>
  • Loading branch information
moble and hawkinsp committed Sep 30, 2024
1 parent 907cde4 commit a0e2218
Showing 1 changed file with 95 additions and 12 deletions.
107 changes: 95 additions & 12 deletions src/numpy_quaternion.c
Original file line number Diff line number Diff line change
@@ -1,18 +1,15 @@
// Copyright (c) 2024, Michael Boyle
// See LICENSE file for details: <https://github.com/moble/quaternion/blob/main/LICENSE>

#define PY_ARRAY_UNIQUE_SYMBOL NumpyQuaternion
// #define NPY_NO_DEPRECATED_API NPY_API_VERSION
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
#define NPY_NO_DEPRECATED_API NPY_API_VERSION

#include <Python.h>
#include <numpy/ndarrayobject.h>
#include <numpy/arrayobject.h>
#include <numpy/npy_math.h>
#include <numpy/ufuncobject.h>
#include "structmember.h"
#include "quaternion.h"

// Provide compatibility with numpy 1 and 2
#include "npy_2_compat.h"
#include "quaternion.h"

// Numpy 1.19 changed UFuncGenericFunction to use const `dimensions` and `steps` pointers.
// Supposedly, this should only generate warnings, but this now causes errors in CI, so
Expand All @@ -25,6 +22,15 @@ typedef npy_intp NPY_INTP_CONST;
typedef npy_intp const NPY_INTP_CONST;
#endif

#if NPY_ABI_VERSION < 0x02000000
#define PyArray_DescrProto PyArray_Descr
#define PyDataType_ELSIZE(d) ((d)->elsize)
#define PyDataType_GetArrFuncs(d) ((d)->f)
#endif

// The following definitions, along with `#define NPY_PY3K 1`, can
// also be found in the header <numpy/npy_3kcompat.h>.
#if PY_MAJOR_VERSION >= 3
#define PyUString_FromString PyUnicode_FromString
static NPY_INLINE int PyInt_Check(PyObject *op) {
int overflow = 0;
Expand All @@ -35,6 +41,16 @@ static NPY_INLINE int PyInt_Check(PyObject *op) {
return (overflow == 0);
}
#define PyInt_AsLong PyLong_AsLong
#else
#define PyUString_FromString PyString_FromString
#endif

// This macro was introduced in python 3.4.2
#ifndef Py_RETURN_NOTIMPLEMENTED
/* Macro for returning Py_NotImplemented from a function */
#define Py_RETURN_NOTIMPLEMENTED \
return Py_INCREF(Py_NotImplemented), Py_NotImplemented
#endif


// The basic python object holding a quaternion
Expand All @@ -47,10 +63,10 @@ static PyTypeObject PyQuaternion_Type;

// This is the crucial feature that will make a quaternion into a
// built-in numpy data type. We will describe its features below.
PyArray_DescrProto* quaternion_descr;
PyArray_Descr* quaternion_descr;

PyArray_DescrProto quaternion_proto = {PyObject_HEAD_INIT(NULL)};

static NPY_INLINE int
PyQuaternion_Check(PyObject* object) {
return PyObject_IsInstance(object,(PyObject*)&PyQuaternion_Type);
Expand Down Expand Up @@ -534,11 +550,19 @@ static int pyquaternion_num_nonzero(PyObject* a) {
}
CANNOT_CONVERT(int)
CANNOT_CONVERT(float)
#if PY_MAJOR_VERSION < 3
CANNOT_CONVERT(long)
CANNOT_CONVERT(oct)
CANNOT_CONVERT(hex)
#endif

static PyNumberMethods pyquaternion_as_number = {
pyquaternion_add, // nb_add
pyquaternion_subtract, // nb_subtract
pyquaternion_multiply, // nb_multiply
#if PY_MAJOR_VERSION < 3
pyquaternion_divide, // nb_divide
#endif
0, // nb_remainder
0, // nb_divmod
pyquaternion_num_power, // nb_power
Expand All @@ -552,12 +576,26 @@ static PyNumberMethods pyquaternion_as_number = {
0, // nb_and
0, // nb_xor
0, // nb_or
#if PY_MAJOR_VERSION < 3
0, // nb_coerce
#endif
pyquaternion_convert_int, // nb_int
#if PY_MAJOR_VERSION >= 3
0, // nb_reserved
#else
pyquaternion_convert_long, // nb_long
#endif
pyquaternion_convert_float, // nb_float
#if PY_MAJOR_VERSION < 3
pyquaternion_convert_oct, // nb_oct
pyquaternion_convert_hex, // nb_hex
#endif
0, // nb_inplace_add
0, // nb_inplace_subtract
0, // nb_inplace_multiply
#if PY_MAJOR_VERSION < 3
0, // nb_inplace_divide
#endif
0, // nb_inplace_remainder
0, // nb_inplace_power
0, // nb_inplace_lshift
Expand All @@ -570,8 +608,12 @@ static PyNumberMethods pyquaternion_as_number = {
0, // nb_inplace_floor_divide
0, // nb_inplace_true_divide
0, // nb_index
#if PY_MAJOR_VERSION >= 3
#if PY_MINOR_VERSION >= 5
0, // nb_matrix_multiply
0, // nb_inplace_matrix_multiply
#endif
#endif
};


Expand Down Expand Up @@ -809,15 +851,24 @@ pyquaternion_str(PyObject *o)
// Note that many of the slots below will be filled later, after the
// corresponding functions are defined.
static PyTypeObject PyQuaternion_Type = {
#if PY_MAJOR_VERSION >= 3
PyVarObject_HEAD_INIT(NULL, 0)
#else
PyObject_HEAD_INIT(NULL)
0, // ob_size
#endif
"quaternion.quaternion", // tp_name
sizeof(PyQuaternion), // tp_basicsize
0, // tp_itemsize
0, // tp_dealloc
0, // tp_print
0, // tp_getattr
0, // tp_setattr
#if PY_MAJOR_VERSION >= 3
0, // tp_reserved
#else
0, // tp_compare
#endif
pyquaternion_repr, // tp_repr
&pyquaternion_as_number, // tp_as_number
0, // tp_as_sequence
Expand All @@ -828,7 +879,11 @@ static PyTypeObject PyQuaternion_Type = {
0, // tp_getattro
0, // tp_setattro
0, // tp_as_buffer
#if PY_MAJOR_VERSION >= 3
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, // tp_flags
#else
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_CHECKTYPES, // tp_flags
#endif
"Floating-point quaternion numbers", // tp_doc
0, // tp_traverse
0, // tp_clear
Expand All @@ -855,8 +910,12 @@ static PyTypeObject PyQuaternion_Type = {
0, // tp_subclasses
0, // tp_weaklist
0, // tp_del
#if PY_VERSION_HEX >= 0x02060000
0, // tp_version_tag
#endif
#if PY_VERSION_HEX >= 0x030400a1
0, // tp_finalize
#endif
};

// Functions implementing internal features. Not all of these function
Expand Down Expand Up @@ -1339,6 +1398,8 @@ int quaternion_alignment = offsetof(align_test, q);
/////////////////////////////////////////////////////////////////


#if PY_MAJOR_VERSION >= 3

static struct PyModuleDef moduledef = {
PyModuleDef_HEAD_INIT,
"numpy_quaternion",
Expand All @@ -1356,6 +1417,15 @@ static struct PyModuleDef moduledef = {
// This is the initialization function that does the setup
PyMODINIT_FUNC PyInit_numpy_quaternion(void) {

#else

#define INITERROR return

// This is the initialization function that does the setup
PyMODINIT_FUNC initnumpy_quaternion(void) {

#endif

PyObject *module;
PyObject *tmp_ufunc;
PyObject *slerp_evaluate_ufunc;
Expand All @@ -1367,15 +1437,24 @@ PyMODINIT_FUNC PyInit_numpy_quaternion(void) {
PyObject* numpy_dict;

// Initialize a (for now, empty) module
#if PY_MAJOR_VERSION >= 3
module = PyModule_Create(&moduledef);
#else
module = Py_InitModule("numpy_quaternion", QuaternionMethods);
#endif

if(module==NULL) {
INITERROR;
}

// Initialize numpy
if (PyArray_ImportNumPyAPI() < 0) {
return NULL;
import_array();
if (PyErr_Occurred()) {
INITERROR;
}
import_umath();
if (PyErr_Occurred()) {
INITERROR;
}
numpy = PyImport_ImportModule("numpy");
if (!numpy) {
Expand Down Expand Up @@ -1641,5 +1720,9 @@ PyMODINIT_FUNC PyInit_numpy_quaternion(void) {
PyModule_AddObject(module, "quaternion", (PyObject *)&PyQuaternion_Type);


#if PY_MAJOR_VERSION >= 3
return module;
}
#else
return;
#endif
}

0 comments on commit a0e2218

Please sign in to comment.