Skip to content

Commit

Permalink
Fixing Vector Support
Browse files Browse the repository at this point in the history
  • Loading branch information
clun committed Sep 19, 2023
1 parent 1bba649 commit 49ab76f
Show file tree
Hide file tree
Showing 4 changed files with 135 additions and 19 deletions.
13 changes: 7 additions & 6 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@
<modelVersion>4.0.0</modelVersion>
<groupId>com.datastax.jdbc</groupId>
<artifactId>astra-jdbc-wapper</artifactId>
<version>0.1.0</version>
<version>0.1.1-SNAPSHOT</version>
<name>ing-jdbc-wrapper-shaded</name>

<properties>
<astra-sdk.version>0.6.3</astra-sdk.version>
<cassandra-driver.version>4.16.0</cassandra-driver.version>
<cassandra-jdbc.version>4.9.0</cassandra-jdbc.version>
<java.version>1.8</java.version>
<astra-sdk.version>0.6.11</astra-sdk.version>
<cassandra-driver.version>4.17.0</cassandra-driver.version>
<cassandra-jdbc.version>4.9.1</cassandra-jdbc.version>

<slf4j.version>2.0.7</slf4j.version>
<logback.version>1.4.8</logback.version>
Expand Down Expand Up @@ -76,8 +77,8 @@
<artifactId>maven-compiler-plugin</artifactId>
<version>${maven-compiler-plugin.version}</version>
<configuration>
<source>11</source>
<target>11</target>
<source>${java.version}</source>
<target>${java.version}</target>
<showWarnings>false</showWarnings>
</configuration>
</plugin>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ public class AstraJdbcDataSource implements ConnectionPoolDataSource, javax.sql.
private String consistencyLevel;
private Integer requestTimeout;

/**
 * Convenience constructor for token-based authentication.
 *
 * <p>Delegates to the four-argument constructor, passing the literal string
 * {@code "token"} as the user and the supplied token as the password.
 *
 * @param token
 *      value forwarded as the password
 * @param database
 *      database value forwarded unchanged
 * @param keyspace
 *      keyspace value forwarded unchanged
 */
public AstraJdbcDataSource(String token, String database, String keyspace) {
this("token", token, database, keyspace);
}

public AstraJdbcDataSource(String user, String password, String database, String keyspace) {
this.user = user;
this.password = password;
Expand Down
21 changes: 8 additions & 13 deletions src/main/java/com/datastax/astra/jdbc/AstraJdbcDriver.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
import com.datastax.oss.driver.api.core.CqlSessionBuilder;
import com.dtsx.astra.sdk.db.AstraDbClient;
import com.dtsx.astra.sdk.db.DatabaseClient;
import com.github.benmanes.caffeine.cache.CacheLoader;
import com.github.benmanes.caffeine.cache.Caffeine;
import com.github.benmanes.caffeine.cache.LoadingCache;
import com.ing.data.cassandra.jdbc.CassandraConnection;
import com.ing.data.cassandra.jdbc.utils.DriverUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand All @@ -22,8 +22,7 @@
import java.util.List;
import java.util.Properties;

import static com.ing.data.cassandra.jdbc.Utils.getDriverProperty;
import static com.ing.data.cassandra.jdbc.Utils.parseVersion;
import static com.ing.data.cassandra.jdbc.utils.DriverUtil.getDriverProperty;

/**
* This class wraps the DataStax Java Driver for Apache Cassandra.
Expand Down Expand Up @@ -51,13 +50,9 @@ public class AstraJdbcDriver implements java.sql.Driver {
* Reuse Session when possible.
*/
final LoadingCache<AstraJdbcUrl, CqlSession > cachedSessions = Caffeine.newBuilder()
.build(new CacheLoader<AstraJdbcUrl, CqlSession>() {
@Override
public CqlSession load(final AstraJdbcUrl jdbcUrl)
throws Exception {
LOGGER.info("Creating a new Session for db '" + jdbcUrl.getDatabaseId() + "'");
return buildSession(jdbcUrl);
}
.build(jdbcUrl -> {
LOGGER.info("Creating a new Session for db '" + jdbcUrl.getDatabaseId() + "'");
return buildSession(jdbcUrl);
});

/**
 * Intentionally empty method.
 * NOTE(review): presumably exists so callers can force this class to load
 * (and run any static initialisation) — confirm against the rest of the class.
 */
public static void register() {}
Expand Down Expand Up @@ -102,7 +97,7 @@ public static CqlSession buildSession(AstraJdbcUrl jdbcUrl) {
*/
public Connection connect(String url, Properties properties) throws SQLException {
AstraJdbcUrl jdbcUrl = new AstraJdbcUrl(url, properties);
return new CassandraConnection( this.cachedSessions.get(jdbcUrl),
return new CassandraConnection(this.cachedSessions.get(jdbcUrl),
jdbcUrl.getKeyspace(),
jdbcUrl.getConsistencyLevel(),
jdbcUrl.isDebug(),
Expand Down Expand Up @@ -131,13 +126,13 @@ public boolean acceptsURL(final String url) {
/** {@inheritDoc} */
@Override
public int getMajorVersion() {
return parseVersion(getDriverProperty("driver.version"), 0);
return DriverUtil.parseVersion(getDriverProperty("driver.version"), 0);
}

/** {@inheritDoc} */
@Override
public int getMinorVersion() {
return parseVersion(getDriverProperty("driver.version"), 1);
return DriverUtil.parseVersion(getDriverProperty("driver.version"), 1);
}

/** {@inheritDoc} */
Expand Down
116 changes: 116 additions & 0 deletions src/test/java/com/datastax/astra/jdbc/jdbc/TestVector.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
package com.datastax.astra.jdbc.jdbc;

import com.datastax.astra.jdbc.AstraJdbcDataSource;
import com.dtsx.astra.sdk.db.AstraDbClient;
import com.dtsx.astra.sdk.utils.TestUtils;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.util.Optional;

/**
 * Integration test checking that the CQL {@code vector} type and ANN
 * similarity search work through the Astra JDBC wrapper.
 *
 * <p>The test needs the {@code ASTRA_DB_APPLICATION_TOKEN} environment
 * variable; setup fails fast with an {@link IllegalStateException} otherwise.
 */
public class TestVector {

    static final Logger logger = LoggerFactory.getLogger(TestVector.class);

    private static final String TEST_DB = "test_jdbc_wrapper";
    private static final String TEST_KEYSPACE = "test";

    /** Common prefix of every INSERT issued while populating the sample table. */
    private static final String INSERT_PREFIX =
            "INSERT INTO pet_supply_vectors (product_id, product_name, product_vector) ";

    /** VALUES clauses of the sample rows: product id, product name and a 14-dimension vector. */
    private static final String[] SAMPLE_ROWS = {
            "VALUES ('pf1843','HealthyFresh - Chicken raw dog food',[1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0])",
            "VALUES ('pf1844','HealthyFresh - Beef raw dog food',[1, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0])",
            "VALUES ('pt0021','Dog Tennis Ball Toy',[0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0])",
            "VALUES ('pt0041','Dog Ring Chew Toy',[0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0])",
            "VALUES ('pf7043','PupperSausage Bacon dog Treats',[0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1])",
            "VALUES ('pf7044','PupperSausage Beef dog Treats',[0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0])"
    };

    private static String token;
    private static AstraJdbcDataSource jdbcDataSource;

    /**
     * Reads the token, provisions the vector database, builds the data source
     * and creates/populates the test table.
     *
     * @throws SQLException
     *      if the schema or the sample data cannot be created
     */
    @BeforeAll
    public static void setupDb() throws SQLException {
        token = Optional.ofNullable(System.getenv("ASTRA_DB_APPLICATION_TOKEN"))
                .orElseThrow(() -> new IllegalStateException("Please define env variable ASTRA_DB_APPLICATION_TOKEN"));
        logger.info("[setup] - Token found");

        // Create DB
        TestUtils.setupVectorDatabase(TEST_DB, TEST_KEYSPACE);
        logger.info("[setup] - DB Setup");

        // Create DataSource
        jdbcDataSource = new AstraJdbcDataSource(token, TEST_DB, TEST_KEYSPACE);
        logger.info("[setup] - Jdbc Connection established");

        // Create Tables and data
        createVectorTable();
        logger.info("[setup] - Schema Created");
    }

    /**
     * Runs an ANN query against the sample table and expects at least one row.
     *
     * @throws SQLException
     *      propagated so that a driver or query failure fails the test instead
     *      of being printed and ignored
     */
    @Test
    public void testSimilaritySearch() throws SQLException {
        final String query = "" +
                "SELECT\n" +
                " product_id, product_vector,\n" +
                " similarity_dot_product(product_vector,[1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]) as similarity\n" +
                "FROM pet_supply_vectors\n" +
                "ORDER BY product_vector\n" +
                "ANN OF [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]\n" +
                "LIMIT 2;";
        // All three JDBC resources are closed in reverse order, even on failure.
        try (Connection connection = jdbcDataSource.getConnection();
             PreparedStatement ps = connection.prepareStatement(query);
             java.sql.ResultSet rs = ps.executeQuery()) {
            // A result has been found
            Assertions.assertTrue(rs.next(), "Similarity search returned no row");
            logger.info("Similarity Search succeed {}", rs.getObject("product_vector"));
        }
    }

    /**
     * Creates the vector table and its search index, then inserts the sample rows,
     * all through JDBC.
     *
     * @throws SQLException
     *      if any statement fails
     */
    private static void createVectorTable() throws SQLException {
        // One statement is reused for every command and closed with the connection.
        try (Connection conn = jdbcDataSource.getConnection();
             java.sql.Statement stmt = conn.createStatement()) {

            // Create a Table with Embeddings
            stmt.execute("" +
                    "CREATE TABLE IF NOT EXISTS pet_supply_vectors (" +
                    " product_id TEXT PRIMARY KEY," +
                    " product_name TEXT," +
                    " product_vector vector<float, 14>)");
            logger.info("Table created.");

            // Create a Search Index
            stmt.execute("" +
                    "CREATE CUSTOM INDEX IF NOT EXISTS idx_vector " +
                    "ON pet_supply_vectors(product_vector) " +
                    "USING 'StorageAttachedIndex'");
            logger.info("Index Created.");

            // Insert rows
            for (String row : SAMPLE_ROWS) {
                stmt.execute(INSERT_PREFIX + row);
            }
            logger.info("Table populated.");
        }
    }
}

0 comments on commit 49ab76f

Please sign in to comment.