Skip to content

Commit

Permalink
Merge pull request #30 from Mause/feature/jdbc-struct
Browse files Browse the repository at this point in the history
feat: struct, list, and map parameter support
  • Loading branch information
Mause authored Jun 12, 2024
2 parents b2393e0 + 456c7bf commit 329b9aa
Show file tree
Hide file tree
Showing 8 changed files with 381 additions and 73 deletions.
204 changes: 145 additions & 59 deletions src/jni/duckdb_java.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ static jclass J_Float;
static jclass J_Double;
static jclass J_String;
static jclass J_Timestamp;
static jmethodID J_Timestamp_valueOf;
static jclass J_TimestampTZ;
static jclass J_Decimal;
static jclass J_ByteArray;
Expand Down Expand Up @@ -70,11 +71,22 @@ static jfieldID J_DuckVector_varlen;
static jclass J_DuckArray;
static jmethodID J_DuckArray_init;

static jclass J_Struct;
static jmethodID J_Struct_getSQLTypeName;
static jmethodID J_Struct_getAttributes;

static jclass J_Array;
static jmethodID J_Array_getBaseTypeName;
static jmethodID J_Array_getArray;

static jclass J_DuckStruct;
static jmethodID J_DuckStruct_init;

static jclass J_ByteBuffer;

static jclass J_DuckMap;
static jmethodID J_DuckMap_getSQLTypeName;

static jmethodID J_Map_entrySet;
static jmethodID J_Set_iterator;
static jmethodID J_Iterator_hasNext;
Expand All @@ -89,6 +101,7 @@ static jmethodID J_UUID_getLeastSignificantBits;
static jclass J_DuckDBDate;
static jmethodID J_DuckDBDate_getDaysSinceEpoch;

static jclass J_Object;
static jmethodID J_Object_toString;

static jclass J_DuckDBTime;
Expand Down Expand Up @@ -163,9 +176,12 @@ JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM *vm, void *reserved) {
tmpLocalRef = env->FindClass("java/lang/String");
J_String = (jclass)env->NewGlobalRef(tmpLocalRef);
env->DeleteLocalRef(tmpLocalRef);

tmpLocalRef = env->FindClass("org/duckdb/DuckDBTimestamp");
J_Timestamp = (jclass)env->NewGlobalRef(tmpLocalRef);
env->DeleteLocalRef(tmpLocalRef);
J_Timestamp_valueOf = env->GetStaticMethodID(J_Timestamp, "valueOf", "(Ljava/lang/Object;)Ljava/lang/Object;");

tmpLocalRef = env->FindClass("org/duckdb/DuckDBTimestampTZ");
J_TimestampTZ = (jclass)env->NewGlobalRef(tmpLocalRef);
env->DeleteLocalRef(tmpLocalRef);
Expand All @@ -183,6 +199,11 @@ JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM *vm, void *reserved) {
J_ByteArray = (jclass)env->NewGlobalRef(tmpLocalRef);
env->DeleteLocalRef(tmpLocalRef);

J_DuckMap = GetClassRef(env, "org/duckdb/user/DuckDBMap");
D_ASSERT(J_DuckMap);
J_DuckMap_getSQLTypeName = env->GetMethodID(J_DuckMap, "getSQLTypeName", "()Ljava/lang/String;");
D_ASSERT(J_DuckMap_getSQLTypeName);

tmpLocalRef = env->FindClass("java/util/Map");
J_Map_entrySet = env->GetMethodID(tmpLocalRef, "entrySet", "()Ljava/util/Set;");
env->DeleteLocalRef(tmpLocalRef);
Expand All @@ -209,13 +230,21 @@ JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM *vm, void *reserved) {
D_ASSERT(J_DuckArray_init);
env->DeleteLocalRef(tmpLocalRef);

tmpLocalRef = env->FindClass("org/duckdb/DuckDBStruct");
D_ASSERT(tmpLocalRef);
J_DuckStruct = (jclass)env->NewGlobalRef(tmpLocalRef);
J_DuckStruct = GetClassRef(env, "org/duckdb/DuckDBStruct");
J_DuckStruct_init =
env->GetMethodID(J_DuckStruct, "<init>", "([Ljava/lang/String;[Lorg/duckdb/DuckDBVector;ILjava/lang/String;)V");
D_ASSERT(J_DuckStruct_init);
env->DeleteLocalRef(tmpLocalRef);

J_Struct = GetClassRef(env, "java/sql/Struct");
J_Struct_getSQLTypeName = env->GetMethodID(J_Struct, "getSQLTypeName", "()Ljava/lang/String;");
J_Struct_getAttributes = env->GetMethodID(J_Struct, "getAttributes", "()[Ljava/lang/Object;");

J_Array = GetClassRef(env, "java/sql/Array");
J_Array_getArray = env->GetMethodID(J_Array, "getArray", "()Ljava/lang/Object;");
J_Array_getBaseTypeName = env->GetMethodID(J_Array, "getBaseTypeName", "()Ljava/lang/String;");

J_Object = GetClassRef(env, "java/lang/Object");
J_Object_toString = env->GetMethodID(J_Object, "toString", "()Ljava/lang/String;");

tmpLocalRef = env->FindClass("java/util/Map$Entry");
J_Entry_getKey = env->GetMethodID(tmpLocalRef, "getKey", "()Ljava/lang/Object;");
Expand Down Expand Up @@ -559,11 +588,120 @@ struct ResultHolder {
duckdb::unique_ptr<DataChunk> chunk;
};

Value ToValue(JNIEnv *env, jobject param, duckdb::shared_ptr<ClientContext> context) {
param = env->CallStaticObjectMethod(J_Timestamp, J_Timestamp_valueOf, param);

if (param == nullptr) {
return (Value());
} else if (env->IsInstanceOf(param, J_Bool)) {
return (Value::BOOLEAN(env->CallBooleanMethod(param, J_Bool_booleanValue)));
} else if (env->IsInstanceOf(param, J_Byte)) {
return (Value::TINYINT(env->CallByteMethod(param, J_Byte_byteValue)));
} else if (env->IsInstanceOf(param, J_Short)) {
return (Value::SMALLINT(env->CallShortMethod(param, J_Short_shortValue)));
} else if (env->IsInstanceOf(param, J_Int)) {
return (Value::INTEGER(env->CallIntMethod(param, J_Int_intValue)));
} else if (env->IsInstanceOf(param, J_Long)) {
return (Value::BIGINT(env->CallLongMethod(param, J_Long_longValue)));
} else if (env->IsInstanceOf(param, J_TimestampTZ)) { // Check for subclass before superclass!
return (
Value::TIMESTAMPTZ((timestamp_t)env->CallLongMethod(param, J_TimestampTZ_getMicrosEpoch)));
} else if (env->IsInstanceOf(param, J_DuckDBDate)) {
return (
Value::DATE((date_t)env->CallLongMethod(param, J_DuckDBDate_getDaysSinceEpoch)));

} else if (env->IsInstanceOf(param, J_DuckDBTime)) {
return (Value::TIME((dtime_t)env->CallLongMethod(param, J_Timestamp_getMicrosEpoch)));
} else if (env->IsInstanceOf(param, J_Timestamp)) {
return (
Value::TIMESTAMP((timestamp_t)env->CallLongMethod(param, J_Timestamp_getMicrosEpoch)));
} else if (env->IsInstanceOf(param, J_Float)) {
return (Value::FLOAT(env->CallFloatMethod(param, J_Float_floatValue)));
} else if (env->IsInstanceOf(param, J_Double)) {
return (Value::DOUBLE(env->CallDoubleMethod(param, J_Double_doubleValue)));
} else if (env->IsInstanceOf(param, J_Decimal)) {
Value val = create_value_from_bigdecimal(env, param);
return (val);
} else if (env->IsInstanceOf(param, J_String)) {
auto param_string = jstring_to_string(env, (jstring)param);
return (Value(param_string));
} else if (env->IsInstanceOf(param, J_ByteArray)) {
return (Value::BLOB_RAW(byte_array_to_string(env, (jbyteArray)param)));
} else if (env->IsInstanceOf(param, J_UUID)) {
auto most_significant = (jlong)env->CallObjectMethod(param, J_UUID_getMostSignificantBits);
auto least_significant = (jlong)env->CallObjectMethod(param, J_UUID_getLeastSignificantBits);
return (Value::UUID(hugeint_t(most_significant, least_significant)));
} else if (env->IsInstanceOf(param, J_DuckMap)) {
auto typeName = jstring_to_string(env, (jstring)env->CallObjectMethod(param, J_DuckMap_getSQLTypeName));

LogicalType type;
context->RunFunctionInTransaction([&]() { type = TransformStringToLogicalType(typeName, *context); });

auto entrySet = env->CallObjectMethod(param, J_Map_entrySet);
auto iterator = env->CallObjectMethod(entrySet, J_Set_iterator);
duckdb::vector<Value> entries;
while (env->CallBooleanMethod(iterator, J_Iterator_hasNext)) {
auto entry = env->CallObjectMethod(iterator, J_Iterator_next);

auto key = env->CallObjectMethod(entry, J_Entry_getKey);
auto value = env->CallObjectMethod(entry, J_Entry_getValue);
D_ASSERT(key);
D_ASSERT(value);

entries.push_back(Value::STRUCT({{"key", ToValue(env, key, context)}, {"value", ToValue(env, value, context)}}));
}

return (Value::MAP(ListType::GetChildType(type), entries));

} else if (env->IsInstanceOf(param, J_Struct)) {
auto typeName = jstring_to_string(env, (jstring)env->CallObjectMethod(param, J_Struct_getSQLTypeName));

LogicalType type;
context->RunFunctionInTransaction([&]() { type = TransformStringToLogicalType(typeName, *context); });

auto jvalues = (jobjectArray)env->CallObjectMethod(param, J_Struct_getAttributes);

int size = env->GetArrayLength(jvalues);

child_list_t<Value> values;

for (int i = 0; i < size; i++) {
auto name = StructType::GetChildName(type, i);

auto value = env->GetObjectArrayElement(jvalues, i);

values.emplace_back(name, ToValue(env, value, context));
}

return (Value::STRUCT(std::move(values)));
} else if (env->IsInstanceOf(param, J_Array)) {
auto typeName = jstring_to_string(env, (jstring)env->CallObjectMethod(param, J_Array_getBaseTypeName));
auto jvalues = (jobjectArray)env->CallObjectMethod(param, J_Array_getArray);
int size = env->GetArrayLength(jvalues);

LogicalType type;
context->RunFunctionInTransaction([&]() { type = TransformStringToLogicalType(typeName, *context); });

duckdb::vector<Value> values;
for (int i = 0; i < size; i++) {
auto value = env->GetObjectArrayElement(jvalues, i);

values.emplace_back(ToValue(env, value, context));
}

return (Value::LIST(type, values));

} else {
throw InvalidInputException("Unsupported parameter type");
}
}

jobject _duckdb_jdbc_execute(JNIEnv *env, jclass, jobject stmt_ref_buf, jobjectArray params) {
auto stmt_ref = (StatementHolder *)env->GetDirectBufferAddress(stmt_ref_buf);
if (!stmt_ref) {
throw InvalidInputException("Invalid statement");
}

auto res_ref = make_uniq<ResultHolder>();
duckdb::vector<Value> duckdb_params;

Expand All @@ -572,64 +710,12 @@ jobject _duckdb_jdbc_execute(JNIEnv *env, jclass, jobject stmt_ref_buf, jobjectA
throw InvalidInputException("Parameter count mismatch");
}

auto &context = stmt_ref->stmt->context;

if (param_len > 0) {
for (idx_t i = 0; i < param_len; i++) {
auto param = env->GetObjectArrayElement(params, i);
if (param == nullptr) {
duckdb_params.push_back(Value());
continue;
} else if (env->IsInstanceOf(param, J_Bool)) {
duckdb_params.push_back(Value::BOOLEAN(env->CallBooleanMethod(param, J_Bool_booleanValue)));
continue;
} else if (env->IsInstanceOf(param, J_Byte)) {
duckdb_params.push_back(Value::TINYINT(env->CallByteMethod(param, J_Byte_byteValue)));
continue;
} else if (env->IsInstanceOf(param, J_Short)) {
duckdb_params.push_back(Value::SMALLINT(env->CallShortMethod(param, J_Short_shortValue)));
continue;
} else if (env->IsInstanceOf(param, J_Int)) {
duckdb_params.push_back(Value::INTEGER(env->CallIntMethod(param, J_Int_intValue)));
continue;
} else if (env->IsInstanceOf(param, J_Long)) {
duckdb_params.push_back(Value::BIGINT(env->CallLongMethod(param, J_Long_longValue)));
continue;
} else if (env->IsInstanceOf(param, J_TimestampTZ)) { // Check for subclass before superclass!
duckdb_params.push_back(
Value::TIMESTAMPTZ((timestamp_t)env->CallLongMethod(param, J_TimestampTZ_getMicrosEpoch)));
continue;
} else if (env->IsInstanceOf(param, J_DuckDBDate)) {
duckdb_params.push_back(
Value::DATE((date_t)env->CallLongMethod(param, J_DuckDBDate_getDaysSinceEpoch)));

} else if (env->IsInstanceOf(param, J_DuckDBTime)) {
duckdb_params.push_back(Value::TIME((dtime_t)env->CallLongMethod(param, J_Timestamp_getMicrosEpoch)));

} else if (env->IsInstanceOf(param, J_Timestamp)) {
duckdb_params.push_back(
Value::TIMESTAMP((timestamp_t)env->CallLongMethod(param, J_Timestamp_getMicrosEpoch)));
continue;
} else if (env->IsInstanceOf(param, J_Float)) {
duckdb_params.push_back(Value::FLOAT(env->CallFloatMethod(param, J_Float_floatValue)));
continue;
} else if (env->IsInstanceOf(param, J_Double)) {
duckdb_params.push_back(Value::DOUBLE(env->CallDoubleMethod(param, J_Double_doubleValue)));
continue;
} else if (env->IsInstanceOf(param, J_Decimal)) {
Value val = create_value_from_bigdecimal(env, param);
duckdb_params.push_back(val);
continue;
} else if (env->IsInstanceOf(param, J_String)) {
auto param_string = jstring_to_string(env, (jstring)param);
duckdb_params.push_back(Value(param_string));
} else if (env->IsInstanceOf(param, J_ByteArray)) {
duckdb_params.push_back(Value::BLOB_RAW(byte_array_to_string(env, (jbyteArray)param)));
} else if (env->IsInstanceOf(param, J_UUID)) {
auto most_significant = (jlong)env->CallObjectMethod(param, J_UUID_getMostSignificantBits);
auto least_significant = (jlong)env->CallObjectMethod(param, J_UUID_getLeastSignificantBits);
duckdb_params.push_back(Value::UUID(hugeint_t(most_significant, least_significant)));
} else {
throw InvalidInputException("Unsupported parameter type");
}
duckdb_params.push_back(ToValue(env, param, context));
}
}

Expand Down
12 changes: 10 additions & 2 deletions src/main/java/org/duckdb/DuckDBConnection.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
package org.duckdb;

import org.duckdb.user.DuckDBMap;
import org.duckdb.user.DuckDBUserArray;
import org.duckdb.user.DuckDBUserStruct;

import java.lang.reflect.InvocationTargetException;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
Expand Down Expand Up @@ -327,11 +331,15 @@ public Properties getClientInfo() throws SQLException {
}

public Array createArrayOf(String typeName, Object[] elements) throws SQLException {
throw new SQLFeatureNotSupportedException("createArrayOf");
return new DuckDBUserArray(typeName, elements);
}

public Struct createStruct(String typeName, Object[] attributes) throws SQLException {
throw new SQLFeatureNotSupportedException("createStruct");
return new DuckDBUserStruct(typeName, attributes);
}

public <K, V> Map<K, V> createMap(String typeName, Map<K, V> map) {
return new DuckDBMap<>(typeName, map);
}

public void setNetworkTimeout(Executor executor, int milliseconds) throws SQLException {
Expand Down
12 changes: 0 additions & 12 deletions src/main/java/org/duckdb/DuckDBPreparedStatement.java
Original file line number Diff line number Diff line change
Expand Up @@ -239,18 +239,6 @@ public void setObject(int parameterIndex, Object x) throws SQLException {
if (params.length == 0) {
params = new Object[getParameterMetaData().getParameterCount()];
}
// Change sql.Timestamp to DuckDBTimestamp
if (x instanceof Timestamp) {
x = new DuckDBTimestamp((Timestamp) x);
} else if (x instanceof LocalDateTime) {
x = new DuckDBTimestamp((LocalDateTime) x);
} else if (x instanceof OffsetDateTime) {
x = new DuckDBTimestampTZ((OffsetDateTime) x);
} else if (x instanceof Date) {
x = new DuckDBDate((Date) x);
} else if (x instanceof Time) {
x = new DuckDBTime((Time) x);
}
params[parameterIndex - 1] = x;
}

Expand Down
19 changes: 19 additions & 0 deletions src/main/java/org/duckdb/DuckDBTimestamp.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package org.duckdb;

import java.sql.Timestamp;
import java.sql.Time;
import java.sql.Date;
import java.time.ZoneOffset;
import java.time.Instant;
import java.time.LocalDateTime;
Expand Down Expand Up @@ -99,6 +101,23 @@ public static long localDateTime2Micros(LocalDateTime localDateTime) {
return DuckDBTimestamp.RefLocalDateTime.until(localDateTime, ChronoUnit.MICROS);
}

// TODO: move this to C++ side
public static Object valueOf(Object x) {
// Change sql.Timestamp to DuckDBTimestamp
if (x instanceof Timestamp) {
x = new DuckDBTimestamp((Timestamp) x);
} else if (x instanceof LocalDateTime) {
x = new DuckDBTimestamp((LocalDateTime) x);
} else if (x instanceof OffsetDateTime) {
x = new DuckDBTimestampTZ((OffsetDateTime) x);
} else if (x instanceof Date) {
x = new DuckDBDate((Date) x);
} else if (x instanceof Time) {
x = new DuckDBTime((Time) x);
}
return x;
}

public Timestamp toSqlTimestamp() {
return Timestamp.valueOf(this.toLocalDateTime());
}
Expand Down
17 changes: 17 additions & 0 deletions src/main/java/org/duckdb/user/DuckDBMap.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package org.duckdb.user;

import java.util.HashMap;
import java.util.Map;

public class DuckDBMap<K, V> extends HashMap<K, V> {
private final String typeName;

public DuckDBMap(String typeName, Map<K, V> map) {
super(map);
this.typeName = typeName;
}

public String getSQLTypeName() {
return typeName;
}
}
Loading

0 comments on commit 329b9aa

Please sign in to comment.