diff --git a/shell/android-studio/flycast/src/main/jni/src/jni_util.h b/shell/android-studio/flycast/src/main/jni/src/jni_util.h index 30f7d2436..2fc0b992a 100644 --- a/shell/android-studio/flycast/src/main/jni/src/jni_util.h +++ b/shell/android-studio/flycast/src/main/jni/src/jni_util.h @@ -249,99 +249,143 @@ class ObjectArray : public Array } }; -class ByteArray : public Array +class ByteArray; +class IntArray; +class ShortArray; +class BooleanArray; + +namespace detail { -public: - using jtype = jbyteArray; +// Use a traits type and specializations to define types needed in the base CRTP template. +template struct JniArrayTraits; +template <> struct JniArrayTraits { + using jtype = jbyteArray; + using ctype = u8; + using vtype = ctype; + static constexpr char const * JNISignature = "[B"; +}; +template <> struct JniArrayTraits { + using jtype = jintArray; + using ctype = int; + using vtype = ctype; + static constexpr char const * JNISignature = "[I"; +}; +template <> struct JniArrayTraits { + using jtype = jshortArray; + using ctype = short; + using vtype = ctype; + static constexpr char const * JNISignature = "[S"; +}; +template <> struct JniArrayTraits { + using jtype = jbooleanArray; + using ctype = bool; + using vtype = u8; // avoid std::vector abomination + static constexpr char const * JNISignature = "[Z"; +}; +} - ByteArray(jobject array = nullptr, bool ownRef = true, bool globalRef = false) : Array(array, ownRef, globalRef) { } - ByteArray(ByteArray &&other) : Array(std::move(other)) {} - explicit ByteArray(size_t size) : ByteArray() { - object = env()->NewByteArray(size); - } +template +class PrimitiveArray : public Array +{ + using ctype = typename detail::JniArrayTraits::ctype; + using vtype = typename detail::JniArrayTraits::vtype; - ByteArray& operator=(const ByteArray& other) { - return (ByteArray&)Object::operator=(other); - } +public: + using jtype = typename detail::JniArrayTraits::jtype; - operator jbyteArray() const { return (jbyteArray)object; } + PrimitiveArray(jobject array = nullptr, bool ownRef = true, bool globalRef = false) : Array(array, ownRef, globalRef) { } + PrimitiveArray(PrimitiveArray &&other) : Array(std::move(other)) {} - void getData(u8 *dst, size_t first = 0, size_t len = 0) const { + operator jtype() const { return static_cast(object); } + + void getData(ctype *dst, size_t first = 0, size_t len = 0) const + { if (len == 0) len = size(); if (len != 0) - env()->GetByteArrayRegion((jbyteArray)object, first, len, (jbyte *)dst); + static_cast(this)->getJavaArrayRegion(object, first, len, dst); } - void setData(const u8 *src, size_t first = 0, size_t len = 0) { + void setData(const ctype *src, size_t first = 0, size_t len = 0) + { if (len == 0) len = size(); - env()->SetByteArrayRegion((jbyteArray)object, first, len, (const jbyte *)src); + static_cast(this)->setJavaArrayRegion(object, first, len, src); } - operator std::vector() const + operator std::vector() const { - std::vector v; + std::vector v; v.resize(size()); - getData(v.data()); + getData(static_cast(v.data())); return v; } static Class getClass() { - return Class(env()->FindClass("[B")); + return Class(env()->FindClass(detail::JniArrayTraits::JNISignature)); } }; -class IntArray : public Array +class ByteArray : public PrimitiveArray { + using super = PrimitiveArray; + friend super; + public: - using jtype = jintArray; + ByteArray(jobject array = nullptr, bool ownRef = true, bool globalRef = false) : super(array, ownRef, globalRef) { } + ByteArray(ByteArray &&other) : super(std::move(other)) {} + explicit ByteArray(size_t size) : ByteArray() { + object = env()->NewByteArray(size); + } - IntArray(jobject array = nullptr, bool ownRef = true, bool globalRef = false) : Array(array, ownRef, globalRef) { } - IntArray(IntArray &&other) : Array(std::move(other)) {} - explicit IntArray(size_t size) : IntArray() { - object = env()->NewIntArray(size); + ByteArray& operator=(const ByteArray& other) { + return (ByteArray&)Object::operator=(other); } - IntArray& operator=(const IntArray& other) { - return (IntArray&)Object::operator=(other); +protected: + void getJavaArrayRegion(jobject object, size_t first, size_t len, u8 *dst) const { + env()->GetByteArrayRegion((jbyteArray)object, first, len, (jbyte *)dst); } - operator jintArray() const { return (jintArray)object; } + void setJavaArrayRegion(jobject object, size_t first, size_t len, const u8 *dst) { + env()->SetByteArrayRegion((jbyteArray)object, first, len, (const jbyte *)dst); + } +}; - void getData(int *dst, size_t first = 0, size_t len = 0) const { - if (len == 0) - len = size(); - if (len != 0) - env()->GetIntArrayRegion((jintArray)object, first, len, (jint *)dst); +class IntArray : public PrimitiveArray +{ + using super = PrimitiveArray; + friend super; + +public: + IntArray(jobject array = nullptr, bool ownRef = true, bool globalRef = false) : super(array, ownRef, globalRef) { } + IntArray(IntArray &&other) : super(std::move(other)) {} + explicit IntArray(size_t size) : IntArray() { + object = env()->NewIntArray(size); } - void setData(const int *src, size_t first = 0, size_t len = 0) { - if (len == 0) - len = size(); - env()->SetIntArrayRegion((jintArray)object, first, len, (const jint *)src); + IntArray& operator=(const IntArray& other) { + return (IntArray&)Object::operator=(other); } - operator std::vector() const - { - std::vector v; - v.resize(size()); - getData(v.data()); - return v; +protected: + void getJavaArrayRegion(jobject object, size_t first, size_t len, int *dst) const { + env()->GetIntArrayRegion((jintArray)object, first, len, (jint *)dst); } - static Class getClass() { - return Class(env()->FindClass("[I")); + void setJavaArrayRegion(jobject object, size_t first, size_t len, const int *dst) { + env()->SetIntArrayRegion((jintArray)object, first, len, (const jint *)dst); } }; -class ShortArray : public Array +class ShortArray : public PrimitiveArray { -public: - using jtype = jshortArray; + using super = PrimitiveArray; + friend super; - ShortArray(jobject array = nullptr, bool ownRef = true, bool globalRef = false) : Array(array, ownRef, globalRef) { } - ShortArray(ShortArray &&other) : Array(std::move(other)) {} +public: + ShortArray(jobject array = nullptr, bool ownRef = true, bool globalRef = false) : super(array, ownRef, globalRef) { } + ShortArray(ShortArray &&other) : super(std::move(other)) {} explicit ShortArray(size_t size) : ShortArray() { object = env()->NewShortArray(size); } @@ -350,23 +394,39 @@ class ShortArray : public Array return (ShortArray&)Object::operator=(other); } - operator jshortArray() const { return (jshortArray)object; } +protected: + void getJavaArrayRegion(jobject object, size_t first, size_t len, short *dst) const { + env()->GetShortArrayRegion((jshortArray)object, first, len, (jshort *)dst); + } - void getData(short *dst, size_t first = 0, size_t len = 0) { - if (len == 0) - len = size(); - if (len != 0) - env()->GetShortArrayRegion((jshortArray)object, first, len, (jshort *)dst); + void setJavaArrayRegion(jobject object, size_t first, size_t len, const short *dst) { + env()->SetShortArrayRegion((jshortArray)object, first, len, (const jshort *)dst); } +}; - void setData(const short *src, size_t first = 0, size_t len = 0) { - if (len == 0) - len = size(); - env()->SetShortArrayRegion((jshortArray)object, first, len, (const jshort *)src); +class BooleanArray : public PrimitiveArray +{ + using super = PrimitiveArray; + friend super; + +public: + BooleanArray(jobject array = nullptr, bool ownRef = true, bool globalRef = false) : super(array, ownRef, globalRef) { } + BooleanArray(BooleanArray &&other) : super(std::move(other)) {} + explicit BooleanArray(size_t size) : BooleanArray() { + object = env()->NewBooleanArray(size); } - static Class getClass() { - return Class(env()->FindClass("[S")); + BooleanArray& operator=(const BooleanArray& other) { + return (BooleanArray&)Object::operator=(other); + } + +protected: + void getJavaArrayRegion(jobject object, size_t first, size_t len, bool *dst) const { + env()->GetBooleanArrayRegion((jbooleanArray)object, first, len, (jboolean *)dst); + } + + void setJavaArrayRegion(jobject object, size_t first, size_t len, const bool *dst) { + env()->SetBooleanArrayRegion((jbooleanArray)object, first, len, (const jboolean *)dst); } };