Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: struct, list, and map parameter support #30

Merged
merged 17 commits into from
Jun 12, 2024
204 changes: 145 additions & 59 deletions src/jni/duckdb_java.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ static jclass J_Float;
static jclass J_Double;
static jclass J_String;
static jclass J_Timestamp;
static jmethodID J_Timestamp_valueOf;
static jclass J_TimestampTZ;
static jclass J_Decimal;
static jclass J_ByteArray;
Expand Down Expand Up @@ -70,11 +71,22 @@ static jfieldID J_DuckVector_varlen;
static jclass J_DuckArray;
static jmethodID J_DuckArray_init;

static jclass J_Struct;
static jmethodID J_Struct_getSQLTypeName;
static jmethodID J_Struct_getAttributes;

static jclass J_Array;
static jmethodID J_Array_getBaseTypeName;
static jmethodID J_Array_getArray;

static jclass J_DuckStruct;
static jmethodID J_DuckStruct_init;

static jclass J_ByteBuffer;

static jclass J_DuckMap;
static jmethodID J_DuckMap_getSQLTypeName;

static jmethodID J_Map_entrySet;
static jmethodID J_Set_iterator;
static jmethodID J_Iterator_hasNext;
Expand All @@ -89,6 +101,7 @@ static jmethodID J_UUID_getLeastSignificantBits;
static jclass J_DuckDBDate;
static jmethodID J_DuckDBDate_getDaysSinceEpoch;

static jclass J_Object;
static jmethodID J_Object_toString;

static jclass J_DuckDBTime;
Expand Down Expand Up @@ -163,9 +176,12 @@ JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM *vm, void *reserved) {
tmpLocalRef = env->FindClass("java/lang/String");
J_String = (jclass)env->NewGlobalRef(tmpLocalRef);
env->DeleteLocalRef(tmpLocalRef);

tmpLocalRef = env->FindClass("org/duckdb/DuckDBTimestamp");
J_Timestamp = (jclass)env->NewGlobalRef(tmpLocalRef);
env->DeleteLocalRef(tmpLocalRef);
J_Timestamp_valueOf = env->GetStaticMethodID(J_Timestamp, "valueOf", "(Ljava/lang/Object;)Ljava/lang/Object;");

tmpLocalRef = env->FindClass("org/duckdb/DuckDBTimestampTZ");
J_TimestampTZ = (jclass)env->NewGlobalRef(tmpLocalRef);
env->DeleteLocalRef(tmpLocalRef);
Expand All @@ -183,6 +199,11 @@ JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM *vm, void *reserved) {
J_ByteArray = (jclass)env->NewGlobalRef(tmpLocalRef);
env->DeleteLocalRef(tmpLocalRef);

J_DuckMap = GetClassRef(env, "org/duckdb/user/DuckDBMap");
D_ASSERT(J_DuckMap);
J_DuckMap_getSQLTypeName = env->GetMethodID(J_DuckMap, "getSQLTypeName", "()Ljava/lang/String;");
D_ASSERT(J_DuckMap_getSQLTypeName);

tmpLocalRef = env->FindClass("java/util/Map");
J_Map_entrySet = env->GetMethodID(tmpLocalRef, "entrySet", "()Ljava/util/Set;");
env->DeleteLocalRef(tmpLocalRef);
Expand All @@ -209,13 +230,21 @@ JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM *vm, void *reserved) {
D_ASSERT(J_DuckArray_init);
env->DeleteLocalRef(tmpLocalRef);

tmpLocalRef = env->FindClass("org/duckdb/DuckDBStruct");
D_ASSERT(tmpLocalRef);
J_DuckStruct = (jclass)env->NewGlobalRef(tmpLocalRef);
J_DuckStruct = GetClassRef(env, "org/duckdb/DuckDBStruct");
J_DuckStruct_init =
env->GetMethodID(J_DuckStruct, "<init>", "([Ljava/lang/String;[Lorg/duckdb/DuckDBVector;ILjava/lang/String;)V");
D_ASSERT(J_DuckStruct_init);
env->DeleteLocalRef(tmpLocalRef);

J_Struct = GetClassRef(env, "java/sql/Struct");
J_Struct_getSQLTypeName = env->GetMethodID(J_Struct, "getSQLTypeName", "()Ljava/lang/String;");
J_Struct_getAttributes = env->GetMethodID(J_Struct, "getAttributes", "()[Ljava/lang/Object;");

J_Array = GetClassRef(env, "java/sql/Array");
J_Array_getArray = env->GetMethodID(J_Array, "getArray", "()Ljava/lang/Object;");
J_Array_getBaseTypeName = env->GetMethodID(J_Array, "getBaseTypeName", "()Ljava/lang/String;");

J_Object = GetClassRef(env, "java/lang/Object");
J_Object_toString = env->GetMethodID(J_Object, "toString", "()Ljava/lang/String;");

tmpLocalRef = env->FindClass("java/util/Map$Entry");
J_Entry_getKey = env->GetMethodID(tmpLocalRef, "getKey", "()Ljava/lang/Object;");
Expand Down Expand Up @@ -559,11 +588,120 @@ struct ResultHolder {
duckdb::unique_ptr<DataChunk> chunk;
};

Value ToValue(JNIEnv *env, jobject param, duckdb::shared_ptr<ClientContext> context) {
param = env->CallStaticObjectMethod(J_Timestamp, J_Timestamp_valueOf, param);

if (param == nullptr) {
return (Value());
} else if (env->IsInstanceOf(param, J_Bool)) {
return (Value::BOOLEAN(env->CallBooleanMethod(param, J_Bool_booleanValue)));
} else if (env->IsInstanceOf(param, J_Byte)) {
return (Value::TINYINT(env->CallByteMethod(param, J_Byte_byteValue)));
} else if (env->IsInstanceOf(param, J_Short)) {
return (Value::SMALLINT(env->CallShortMethod(param, J_Short_shortValue)));
} else if (env->IsInstanceOf(param, J_Int)) {
return (Value::INTEGER(env->CallIntMethod(param, J_Int_intValue)));
} else if (env->IsInstanceOf(param, J_Long)) {
return (Value::BIGINT(env->CallLongMethod(param, J_Long_longValue)));
} else if (env->IsInstanceOf(param, J_TimestampTZ)) { // Check for subclass before superclass!
return (
Value::TIMESTAMPTZ((timestamp_t)env->CallLongMethod(param, J_TimestampTZ_getMicrosEpoch)));
} else if (env->IsInstanceOf(param, J_DuckDBDate)) {
return (
Value::DATE((date_t)env->CallLongMethod(param, J_DuckDBDate_getDaysSinceEpoch)));

} else if (env->IsInstanceOf(param, J_DuckDBTime)) {
return (Value::TIME((dtime_t)env->CallLongMethod(param, J_Timestamp_getMicrosEpoch)));
} else if (env->IsInstanceOf(param, J_Timestamp)) {
return (
Value::TIMESTAMP((timestamp_t)env->CallLongMethod(param, J_Timestamp_getMicrosEpoch)));
} else if (env->IsInstanceOf(param, J_Float)) {
return (Value::FLOAT(env->CallFloatMethod(param, J_Float_floatValue)));
} else if (env->IsInstanceOf(param, J_Double)) {
return (Value::DOUBLE(env->CallDoubleMethod(param, J_Double_doubleValue)));
} else if (env->IsInstanceOf(param, J_Decimal)) {
Value val = create_value_from_bigdecimal(env, param);
return (val);
} else if (env->IsInstanceOf(param, J_String)) {
auto param_string = jstring_to_string(env, (jstring)param);
return (Value(param_string));
} else if (env->IsInstanceOf(param, J_ByteArray)) {
return (Value::BLOB_RAW(byte_array_to_string(env, (jbyteArray)param)));
} else if (env->IsInstanceOf(param, J_UUID)) {
auto most_significant = (jlong)env->CallObjectMethod(param, J_UUID_getMostSignificantBits);
auto least_significant = (jlong)env->CallObjectMethod(param, J_UUID_getLeastSignificantBits);
return (Value::UUID(hugeint_t(most_significant, least_significant)));
} else if (env->IsInstanceOf(param, J_DuckMap)) {
auto typeName = jstring_to_string(env, (jstring)env->CallObjectMethod(param, J_DuckMap_getSQLTypeName));

LogicalType type;
context->RunFunctionInTransaction([&]() { type = TransformStringToLogicalType(typeName, *context); });

auto entrySet = env->CallObjectMethod(param, J_Map_entrySet);
auto iterator = env->CallObjectMethod(entrySet, J_Set_iterator);
duckdb::vector<Value> entries;
while (env->CallBooleanMethod(iterator, J_Iterator_hasNext)) {
auto entry = env->CallObjectMethod(iterator, J_Iterator_next);

auto key = env->CallObjectMethod(entry, J_Entry_getKey);
auto value = env->CallObjectMethod(entry, J_Entry_getValue);
D_ASSERT(key);
D_ASSERT(value);

entries.push_back(Value::STRUCT({{"key", ToValue(env, key, context)}, {"value", ToValue(env, value, context)}}));
}

return (Value::MAP(ListType::GetChildType(type), entries));

} else if (env->IsInstanceOf(param, J_Struct)) {
auto typeName = jstring_to_string(env, (jstring)env->CallObjectMethod(param, J_Struct_getSQLTypeName));

LogicalType type;
context->RunFunctionInTransaction([&]() { type = TransformStringToLogicalType(typeName, *context); });

auto jvalues = (jobjectArray)env->CallObjectMethod(param, J_Struct_getAttributes);

int size = env->GetArrayLength(jvalues);

child_list_t<Value> values;

for (int i = 0; i < size; i++) {
auto name = StructType::GetChildName(type, i);

auto value = env->GetObjectArrayElement(jvalues, i);

values.emplace_back(name, ToValue(env, value, context));
}

return (Value::STRUCT(std::move(values)));
} else if (env->IsInstanceOf(param, J_Array)) {
auto typeName = jstring_to_string(env, (jstring)env->CallObjectMethod(param, J_Array_getBaseTypeName));
auto jvalues = (jobjectArray)env->CallObjectMethod(param, J_Array_getArray);
int size = env->GetArrayLength(jvalues);

LogicalType type;
context->RunFunctionInTransaction([&]() { type = TransformStringToLogicalType(typeName, *context); });

duckdb::vector<Value> values;
for (int i = 0; i < size; i++) {
auto value = env->GetObjectArrayElement(jvalues, i);

values.emplace_back(ToValue(env, value, context));
}

return (Value::LIST(type, values));

} else {
throw InvalidInputException("Unsupported parameter type");
}
}

jobject _duckdb_jdbc_execute(JNIEnv *env, jclass, jobject stmt_ref_buf, jobjectArray params) {
auto stmt_ref = (StatementHolder *)env->GetDirectBufferAddress(stmt_ref_buf);
if (!stmt_ref) {
throw InvalidInputException("Invalid statement");
}

auto res_ref = make_uniq<ResultHolder>();
duckdb::vector<Value> duckdb_params;

Expand All @@ -572,64 +710,12 @@ jobject _duckdb_jdbc_execute(JNIEnv *env, jclass, jobject stmt_ref_buf, jobjectA
throw InvalidInputException("Parameter count mismatch");
}

auto &context = stmt_ref->stmt->context;

if (param_len > 0) {
for (idx_t i = 0; i < param_len; i++) {
auto param = env->GetObjectArrayElement(params, i);
if (param == nullptr) {
duckdb_params.push_back(Value());
continue;
} else if (env->IsInstanceOf(param, J_Bool)) {
duckdb_params.push_back(Value::BOOLEAN(env->CallBooleanMethod(param, J_Bool_booleanValue)));
continue;
} else if (env->IsInstanceOf(param, J_Byte)) {
duckdb_params.push_back(Value::TINYINT(env->CallByteMethod(param, J_Byte_byteValue)));
continue;
} else if (env->IsInstanceOf(param, J_Short)) {
duckdb_params.push_back(Value::SMALLINT(env->CallShortMethod(param, J_Short_shortValue)));
continue;
} else if (env->IsInstanceOf(param, J_Int)) {
duckdb_params.push_back(Value::INTEGER(env->CallIntMethod(param, J_Int_intValue)));
continue;
} else if (env->IsInstanceOf(param, J_Long)) {
duckdb_params.push_back(Value::BIGINT(env->CallLongMethod(param, J_Long_longValue)));
continue;
} else if (env->IsInstanceOf(param, J_TimestampTZ)) { // Check for subclass before superclass!
duckdb_params.push_back(
Value::TIMESTAMPTZ((timestamp_t)env->CallLongMethod(param, J_TimestampTZ_getMicrosEpoch)));
continue;
} else if (env->IsInstanceOf(param, J_DuckDBDate)) {
duckdb_params.push_back(
Value::DATE((date_t)env->CallLongMethod(param, J_DuckDBDate_getDaysSinceEpoch)));

} else if (env->IsInstanceOf(param, J_DuckDBTime)) {
duckdb_params.push_back(Value::TIME((dtime_t)env->CallLongMethod(param, J_Timestamp_getMicrosEpoch)));

} else if (env->IsInstanceOf(param, J_Timestamp)) {
duckdb_params.push_back(
Value::TIMESTAMP((timestamp_t)env->CallLongMethod(param, J_Timestamp_getMicrosEpoch)));
continue;
} else if (env->IsInstanceOf(param, J_Float)) {
duckdb_params.push_back(Value::FLOAT(env->CallFloatMethod(param, J_Float_floatValue)));
continue;
} else if (env->IsInstanceOf(param, J_Double)) {
duckdb_params.push_back(Value::DOUBLE(env->CallDoubleMethod(param, J_Double_doubleValue)));
continue;
} else if (env->IsInstanceOf(param, J_Decimal)) {
Value val = create_value_from_bigdecimal(env, param);
duckdb_params.push_back(val);
continue;
} else if (env->IsInstanceOf(param, J_String)) {
auto param_string = jstring_to_string(env, (jstring)param);
duckdb_params.push_back(Value(param_string));
} else if (env->IsInstanceOf(param, J_ByteArray)) {
duckdb_params.push_back(Value::BLOB_RAW(byte_array_to_string(env, (jbyteArray)param)));
} else if (env->IsInstanceOf(param, J_UUID)) {
auto most_significant = (jlong)env->CallObjectMethod(param, J_UUID_getMostSignificantBits);
auto least_significant = (jlong)env->CallObjectMethod(param, J_UUID_getLeastSignificantBits);
duckdb_params.push_back(Value::UUID(hugeint_t(most_significant, least_significant)));
} else {
throw InvalidInputException("Unsupported parameter type");
}
duckdb_params.push_back(ToValue(env, param, context));
}
}

Expand Down
12 changes: 10 additions & 2 deletions src/main/java/org/duckdb/DuckDBConnection.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
package org.duckdb;

import org.duckdb.user.DuckDBMap;
import org.duckdb.user.DuckDBUserArray;
import org.duckdb.user.DuckDBUserStruct;

import java.lang.reflect.InvocationTargetException;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
Expand Down Expand Up @@ -327,11 +331,15 @@ public Properties getClientInfo() throws SQLException {
}

public Array createArrayOf(String typeName, Object[] elements) throws SQLException {
throw new SQLFeatureNotSupportedException("createArrayOf");
return new DuckDBUserArray(typeName, elements);
}

public Struct createStruct(String typeName, Object[] attributes) throws SQLException {
throw new SQLFeatureNotSupportedException("createStruct");
return new DuckDBUserStruct(typeName, attributes);
}

public <K, V> Map<K, V> createMap(String typeName, Map<K, V> map) {
return new DuckDBMap<>(typeName, map);
}

public void setNetworkTimeout(Executor executor, int milliseconds) throws SQLException {
Expand Down
12 changes: 0 additions & 12 deletions src/main/java/org/duckdb/DuckDBPreparedStatement.java
Original file line number Diff line number Diff line change
Expand Up @@ -239,18 +239,6 @@ public void setObject(int parameterIndex, Object x) throws SQLException {
if (params.length == 0) {
params = new Object[getParameterMetaData().getParameterCount()];
}
// Change sql.Timestamp to DuckDBTimestamp
if (x instanceof Timestamp) {
x = new DuckDBTimestamp((Timestamp) x);
} else if (x instanceof LocalDateTime) {
x = new DuckDBTimestamp((LocalDateTime) x);
} else if (x instanceof OffsetDateTime) {
x = new DuckDBTimestampTZ((OffsetDateTime) x);
} else if (x instanceof Date) {
x = new DuckDBDate((Date) x);
} else if (x instanceof Time) {
x = new DuckDBTime((Time) x);
}
params[parameterIndex - 1] = x;
}

Expand Down
19 changes: 19 additions & 0 deletions src/main/java/org/duckdb/DuckDBTimestamp.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package org.duckdb;

import java.sql.Timestamp;
import java.sql.Time;
import java.sql.Date;
import java.time.ZoneOffset;
import java.time.Instant;
import java.time.LocalDateTime;
Expand Down Expand Up @@ -99,6 +101,23 @@ public static long localDateTime2Micros(LocalDateTime localDateTime) {
return DuckDBTimestamp.RefLocalDateTime.until(localDateTime, ChronoUnit.MICROS);
}

// TODO: move this to C++ side
public static Object valueOf(Object x) {
// Change sql.Timestamp to DuckDBTimestamp
if (x instanceof Timestamp) {
x = new DuckDBTimestamp((Timestamp) x);
} else if (x instanceof LocalDateTime) {
x = new DuckDBTimestamp((LocalDateTime) x);
} else if (x instanceof OffsetDateTime) {
x = new DuckDBTimestampTZ((OffsetDateTime) x);
} else if (x instanceof Date) {
x = new DuckDBDate((Date) x);
} else if (x instanceof Time) {
x = new DuckDBTime((Time) x);
}
return x;
}

public Timestamp toSqlTimestamp() {
return Timestamp.valueOf(this.toLocalDateTime());
}
Expand Down
17 changes: 17 additions & 0 deletions src/main/java/org/duckdb/user/DuckDBMap.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package org.duckdb.user;

import java.util.HashMap;
import java.util.Map;

public class DuckDBMap<K, V> extends HashMap<K, V> {
private final String typeName;

public DuckDBMap(String typeName, Map<K, V> map) {
super(map);
this.typeName = typeName;
}

public String getSQLTypeName() {
return typeName;
}
}
Loading
Loading