diff --git a/bfloat16.cc b/bfloat16.cc index 6b20eb2b7..bd766e5fe 100644 --- a/bfloat16.cc +++ b/bfloat16.cc @@ -26,6 +26,7 @@ limitations under the License. #include #include +#include #include #ifdef DEBUG_CALLS #include @@ -1745,7 +1746,13 @@ namespace greenwaves NPyBfloat16_ArrFuncs.argmax = NPyBfloat16_ArgMaxFunc; NPyBfloat16_ArrFuncs.argmin = NPyBfloat16_ArgMinFunc; + // Py_TYPE has been deprecated since Python 3.9 + #if PY_MINOR_VERSION < 9 Py_TYPE(&NPyBfloat16_Descr) = &PyArrayDescr_Type; + #else + Py_SET_TYPE(&NPyBfloat16_Descr, &PyArrayDescr_Type); + #endif + npy_bfloat16 = PyArray_RegisterDataType(&NPyBfloat16_Descr); bfloat16_type_ptr = &bfloat16_type; if (npy_bfloat16 < 0) diff --git a/setup.py b/setup.py index e234d4fec..1255b9954 100644 --- a/setup.py +++ b/setup.py @@ -30,7 +30,8 @@ def build_extensions(self): module1 = Extension(PACKAGE_NAME, sources=['bfloat16.cc'], - include_dirs=[np.get_include()]) + include_dirs=[np.get_include()], + extra_compile_args=['-std=c++11']) setup(name=PACKAGE_NAME, version='1.1',